Example #1
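All of the snippets below are bare class definitions. They assume the usual imports plus the two target-network helpers found in common pytorch-soft-actor-critic style codebases; a minimal sketch of these shared assumptions follows (QNetwork, GaussianPolicy and a replay buffer are sketched further down, and their constructor signatures vary slightly between the examples).

import os

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam


def hard_update(target, source):
    # Copy every parameter of `source` into `target` (target <- source).
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


def soft_update(target, source, tau):
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)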
class SAC(object):
    def __init__(self):

        self.gamma = 0.99
        self.tau = 0.005
        self.alpha = 0.2
        self.lr = 0.003

        self.target_update_interval = 1
        self.device = torch.device("cpu")

        # 8 phases
        self.num_inputs = 8
        self.num_actions = 1
        self.hidden_size = 256

        self.critic = QNetwork(self.num_inputs, self.num_actions,
                               self.hidden_size).to(self.device)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.lr)

        self.critic_target = QNetwork(self.num_inputs, self.num_actions,
                                      self.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)
        # Copy the parameters of critic to critic_target

        # Target entropy = -dim(A) = -1 for the single continuous action
        self.target_entropy = -torch.Tensor([1.0]).to(self.device).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)

        self.alpha_optimizer = Adam([self.log_alpha], lr=self.lr)

        self.policy = GaussianPolicy(self.num_inputs, self.num_actions,
                                     self.hidden_size).to(self.device)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=self.lr)

    def select_action(self, state):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)  # add a batch dimension
        _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]
        # Detach the action from the graph and move it to the CPU before converting to NumPy

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)
        action_batch = np.expand_dims(action_batch, axis=1)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        # unsqueeze(1) adds a trailing dimension so rewards and masks have shape (batch_size, 1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)
        qf1, qf2 = self.critic(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf_loss = qf1_loss + qf2_loss

        self.critic_optimizer.zero_grad()  # clear previously accumulated gradients
        qf_loss.backward()                 # backpropagate to compute gradients
        self.critic_optimizer.step()       # update the critic parameters

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()
        # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Automatic entropy tuning: J(α) = 𝔼[−α (logπ(at|st) + target_entropy)]
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone()  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None):
        # Create a models/ directory in the current working directory if it does not exist
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        # state_dict() returns the module's learnable parameters and buffers
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))

    def get_alpha(self):
        return self.alpha
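A hypothetical driver loop for the agent above, only to show how select_action and update_parameters are meant to be called; the environment and the replay buffer are assumptions made here, not part of the original snippet.

def train(env, agent, memory, episodes=1000, batch_size=256):
    # Sketch of a training loop: `env` follows the gym reset()/step() API and
    # `memory` is a replay buffer with push(), sample() and __len__
    # (a buffer sketch appears after Example #4).
    updates = 0
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            mask = 0.0 if done else 1.0  # mask = 1 - done, as expected by update_parameters()
            memory.push(state, action, reward, next_state, mask)
            state = next_state
            if len(memory) > batch_size:
                agent.update_parameters(memory, batch_size, updates)
                updates += 1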
Example #2
class SAC(object):
    def __init__(self, num_inputs, action_space, args):

        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha

        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.critic = QNetwork(num_inputs, action_space.shape[0],
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
        if self.automatic_entropy_tuning is True:
            self.target_entropy = -torch.prod(
                torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        self.policy = GaussianPolicy(num_inputs, action_space.shape[0],
                                     args.hidden_size,
                                     action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, evaluate=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target
        qf1, qf2 = self.critic(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        critic_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        pi, log_pi, _ = self.policy.sample(state_batch)
        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean(
        )  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None):
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path, device='cpu'):
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=torch.device(device)))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=torch.device(device)))
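The GaussianPolicy used throughout these examples is not shown. Below is a minimal sketch of a squashed-Gaussian actor matching the sample() contract (action, log-probability, deterministic mean) that most of the examples rely on, assuming the widely used pytorch-soft-actor-critic layout; Example #6's variant returns entropies instead of log-probabilities, and some examples prepend a seed argument.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

LOG_SIG_MIN, LOG_SIG_MAX = -20, 2
EPS = 1e-6


class GaussianPolicy(nn.Module):
    # Squashed Gaussian actor: sample() returns (action, log_prob, deterministic mean).
    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)
        if action_space is None:
            action_scale, action_bias = torch.tensor(1.0), torch.tensor(0.0)
        else:
            action_scale = torch.FloatTensor((action_space.high - action_space.low) / 2.0)
            action_bias = torch.FloatTensor((action_space.high + action_space.low) / 2.0)
        # register as buffers so .to(device) moves them along with the weights
        self.register_buffer('action_scale', action_scale)
        self.register_buffer('action_bias', action_bias)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = torch.clamp(self.log_std_linear(x), LOG_SIG_MIN, LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        normal = Normal(mean, log_std.exp())
        x_t = normal.rsample()  # reparameterization trick
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        # log-prob with the change-of-variables correction for the tanh squashing
        log_prob = normal.log_prob(x_t) - torch.log(self.action_scale * (1 - y_t.pow(2)) + EPS)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean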
Example #3
class SAC(object):
    def __init__(self, num_inputs, action_space, config):

        self.gamma = config['gamma']
        self.tau = config['tau']
        self.alpha = config['alpha']

        self.policy_type = config['policy']
        self.target_update_interval = config['target_update_interval']
        self.automatic_entropy_tuning = config['automatic_entropy_tuning']

        self.device = torch.device(
            'cuda:' + str(config['cuda'])) if torch.cuda.is_available(
            ) and config['cuda'] >= 0 else torch.device('cpu')

        self.critic = QNetwork(num_inputs, action_space.shape[0],
                               config['hidden_size']).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=config['lr'])

        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      config['hidden_size']).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
            if self.automatic_entropy_tuning == True:
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=config['lr'])

            self.policy = GaussianPolicy(num_inputs, action_space.shape[0],
                                         config['hidden_size'],
                                         action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=config['lr'])

    def select_action(self, state, eval=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if eval == False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        qf1, qf2 = self.critic(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean(
        )  # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        self.critic_optim.zero_grad()
        qf2_loss.backward()
        self.critic_optim.step()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self, save_path=None, env_name=None, suffix=None):
        if save_path is None:
            save_path = './models/'

        actor_path = '{}actor_{}_{}'.format(save_path, env_name, suffix)
        critic_path = "{}critic_{}_{}".format(save_path, env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)
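The twin-Q critic (QNetwork) that every example constructs is likewise assumed; a sketch compatible with the critic(state, action) -> (q1, q2) calls above follows (the three-headed variant used by the BEAR example further down would simply add a third head).

import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    # Twin Q-network: two independent heads over the concatenated (state, action) input.
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()
        # Q1 head
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Q2 head
        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        xu = torch.cat([state, action], dim=1)
        x1 = F.relu(self.linear1(xu))
        x1 = F.relu(self.linear2(x1))
        x1 = self.linear3(x1)
        x2 = F.relu(self.linear4(xu))
        x2 = F.relu(self.linear5(x2))
        x2 = self.linear6(x2)
        return x1, x2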
Example #4
class soft_actor_critic_agent(object):
    def __init__(self, num_inputs, action_space, \
                 device, hidden_size, seed, lr, gamma, tau, alpha):

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha

        self.device = device
        self.seed = seed
        torch.manual_seed(seed)

        torch.cuda.manual_seed(seed)
        #torch.cuda.manual_seed_all(seed)
        #torch.backends.cudnn.deterministic=True

        self.critic = QNetwork(seed, num_inputs, action_space.shape[0],
                               hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=lr)

        self.critic_target = QNetwork(seed, num_inputs, action_space.shape[0],
                                      hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
        self.target_entropy = -torch.prod(
            torch.Tensor(action_space.shape).to(self.device)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = Adam([self.log_alpha], lr=lr)
        self.policy = GaussianPolicy(seed, num_inputs, action_space.shape[0], \
                                         hidden_size, action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=lr)

    def select_action(self, state, eval=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if eval == False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        self.critic_optim.zero_grad()
        qf2_loss.backward()
        self.critic_optim.step()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone()  # For TensorboardX logs

        soft_update(self.critic_target, self.critic, self.tau)
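memory.sample() in all of these update_parameters() methods is expected to return five stacked NumPy arrays (state, action, reward, next_state, mask), with mask = 1 - done. A minimal sketch of a uniform replay buffer under that assumption:

import random

import numpy as np


class ReplayMemory:
    # Uniform replay buffer; sample() returns five stacked NumPy arrays.
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, mask):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, mask)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, mask = map(np.stack, zip(*batch))
        return state, action, reward, next_state, mask

    def __len__(self):
        return len(self.buffer)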
Example #5
class SAC(object):
    def __init__(self, num_inputs, action_space, \
                 device, hidden_size, lr, gamma, tau, alpha):

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha

        self.device = device

        self.critic = QNetwork(num_inputs, action_space.shape[0],
                               hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
        self.target_entropy = -torch.prod(
            torch.Tensor(action_space.shape).to(self.device)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = Adam([self.log_alpha], lr=lr)
        self.policy = GaussianPolicy(num_inputs, action_space.shape[0], \
                                         hidden_size, action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=lr)

    def select_action(self, state, eval=False):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if eval == False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates, writer):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        # self.critic_optim.zero_grad()
        # qf1_loss.backward()
        # self.critic_optim.step()

        # self.critic_optim.zero_grad()
        # qf2_loss.backward()
        # self.critic_optim.step()
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        # self.alpha_optim.zero_grad()
        # alpha_loss.backward()
        # self.alpha_optim.step()

        # self.alpha = self.log_alpha.exp()
        # alpha_tlogs = self.alpha.clone() # For TensorboardX logs

        # writer.add_scalar('alpha', alpha_tlogs, updates)

        soft_update(self.critic_target, self.critic, self.tau)

    # Save model parameters
    def save_model(self,
                   actor_path='models/actor.pth',
                   critic_path='models/critic.pth'):
        if not os.path.exists('models/'):
            os.makedirs('models/')

        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self,
                   actor_path='models/actor.pth',
                   critic_path='models/critic.pth'):
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=torch.device('cpu')))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=torch.device('cpu')))
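Example #5 keeps the temperature update commented out and uses a fixed alpha. For reference, a toy sketch (with made-up numbers) of the automatic entropy-tuning step that the other examples perform:

import torch
from torch.optim import Adam

target_entropy = -6.0                            # -dim(A), e.g. a 6-dim action space
log_alpha = torch.zeros(1, requires_grad=True)   # we optimize log(alpha), not alpha
alpha_optim = Adam([log_alpha], lr=0.0003)

log_pi = torch.full((256, 1), -4.0)              # dummy batch of log-probs from the policy
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()

alpha = log_alpha.exp()
# alpha shrinks here: the current entropy (about -log_pi = 4) is above the target (-6),
# so the agent needs less of an entropy bonus.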
Example #6
class SacAgent:

    def __init__(self, env, log_dir, num_steps=3000000, batch_size=256,
                 lr=0.0003, hidden_units=[256, 256], memory_size=1e6,
                 gamma=0.99, tau=0.005, entropy_tuning=True, ent_coef=0.2,
                 multi_step=1, per=False, alpha=0.6, beta=0.4,
                 beta_annealing=0.0001, grad_clip=None, updates_per_step=1,
                 start_steps=10000, log_interval=10, target_update_interval=1,
                 eval_interval=1000, cuda=True, seed=0):
        self.env = env

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        torch.backends.cudnn.deterministic = True  # deterministic mode can hurt performance
        torch.backends.cudnn.benchmark = False

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.policy = GaussianPolicy(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device)
        self.critic = TwinnedQNetwork(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device)
        self.critic_target = TwinnedQNetwork(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device).eval()

        # copy parameters of the learning network to the target network
        hard_update(self.critic_target, self.critic)
        # disable gradient calculations of the target network
        grad_false(self.critic_target)

        self.policy_optim = Adam(self.policy.parameters(), lr=lr)
        self.q1_optim = Adam(self.critic.Q1.parameters(), lr=lr)
        self.q2_optim = Adam(self.critic.Q2.parameters(), lr=lr)

        if entropy_tuning:
            # Target entropy is -|A|.
            self.target_entropy = -torch.prod(torch.Tensor(
                self.env.action_space.shape).to(self.device)).item()
            # We optimize log(alpha), instead of alpha.
            self.log_alpha = torch.zeros(
                1, requires_grad=True, device=self.device)
            self.alpha = self.log_alpha.exp()
            self.alpha_optim = Adam([self.log_alpha], lr=lr)
        else:
            # fixed alpha
            self.alpha = torch.tensor(ent_coef).to(self.device)

        if per:
            # replay memory with prioritized experience replay
            # See https://github.com/ku2482/rltorch/blob/master/rltorch/memory
            self.memory = PrioritizedMemory(
                memory_size, self.env.observation_space.shape,
                self.env.action_space.shape, self.device, gamma, multi_step,
                alpha=alpha, beta=beta, beta_annealing=beta_annealing)
        else:
            # replay memory without prioritized experience replay
            # See https://github.com/ku2482/rltorch/blob/master/rltorch/memory
            self.memory = MultiStepMemory(
                memory_size, self.env.observation_space.shape,
                self.env.action_space.shape, self.device, gamma, multi_step)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.num_steps = num_steps
        self.tau = tau
        self.per = per
        self.batch_size = batch_size
        self.start_steps = start_steps
        self.gamma_n = gamma ** multi_step
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.log_interval = log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval

    def run(self):
        while True:
            self.train_episode()
            if self.steps > self.num_steps:
                break

    def is_update(self):
        return len(self.memory) > self.batch_size and\
            self.steps >= self.start_steps

    def act(self, state):
        if self.start_steps > self.steps:
            action = self.env.action_space.sample()
        else:
            action = self.explore(state)
        return action

    def explore(self, state):
        # act with randomness
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action, _, _ = self.policy.sample(state)
        return action.cpu().numpy().reshape(-1)

    def exploit(self, state):
        # act without randomness
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            _, _, action = self.policy.sample(state)
        return action.cpu().numpy().reshape(-1)

    def calc_current_q(self, states, actions, rewards, next_states, dones):
        curr_q1, curr_q2 = self.critic(states, actions)
        return curr_q1, curr_q2

    def calc_target_q(self, states, actions, rewards, next_states, dones):
        with torch.no_grad():
            next_actions, next_entropies, _ = self.policy.sample(next_states)
            next_q1, next_q2 = self.critic_target(next_states, next_actions)
            next_q = torch.min(next_q1, next_q2) + self.alpha * next_entropies

        target_q = rewards + (1.0 - dones) * self.gamma_n * next_q

        return target_q

    def train_episode(self):
        self.episodes += 1
        episode_reward = 0.
        episode_steps = 0
        done = False
        state = self.env.reset()

        while not done:
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            self.steps += 1
            episode_steps += 1
            episode_reward += reward

            # ignore `done` when the episode ends only because the time limit
            # was reached (set done=True only when the agent actually fails)
            if episode_steps >= self.env._max_episode_steps:
                masked_done = False
            else:
                masked_done = done

            if self.per:
                batch = to_batch(
                    state, action, reward, next_state, masked_done,
                    self.device)
                with torch.no_grad():
                    curr_q1, curr_q2 = self.calc_current_q(*batch)
                target_q = self.calc_target_q(*batch)
                error = torch.abs(curr_q1 - target_q).item()
                # We pass the true done signal in addition to the masked one so
                # that multi-step rewards are computed correctly.
                self.memory.append(
                    state, action, reward, next_state, masked_done, error,
                    episode_done=done)
            else:
                # We pass the true done signal in addition to the masked one so
                # that multi-step rewards are computed correctly.
                self.memory.append(
                    state, action, reward, next_state, masked_done,
                    episode_done=done)

            if self.is_update():
                for _ in range(self.updates_per_step):
                    self.learn()

            if self.steps % self.eval_interval == 0:
                self.evaluate()
                self.save_models()

            state = next_state

        # We log running mean of training rewards.
        self.train_rewards.append(episode_reward)

        if self.episodes % self.log_interval == 0:
            self.writer.add_scalar(
                'reward/train', self.train_rewards.get(), self.steps)

        print(f'episode: {self.episodes:<4}  '
              f'episode steps: {episode_steps:<4}  '
              f'reward: {episode_reward:<5.1f}')

    def learn(self):
        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        if self.per:
            # batch with indices and priority weights
            batch, indices, weights = \
                self.memory.sample(self.batch_size)
        else:
            batch = self.memory.sample(self.batch_size)
            # set priority weights to 1 when we don't use PER.
            weights = 1.

        q1_loss, q2_loss, errors, mean_q1, mean_q2 =\
            self.calc_critic_loss(batch, weights)
        policy_loss, entropies = self.calc_policy_loss(batch, weights)

        update_params(
            self.q1_optim, self.critic.Q1, q1_loss, self.grad_clip)
        update_params(
            self.q2_optim, self.critic.Q2, q2_loss, self.grad_clip)
        update_params(
            self.policy_optim, self.policy, policy_loss, self.grad_clip)

        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies, weights)
            update_params(self.alpha_optim, None, entropy_loss)
            self.alpha = self.log_alpha.exp()
            self.writer.add_scalar(
                'loss/alpha', entropy_loss.detach().item(), self.steps)

        if self.per:
            # update priority weights
            self.memory.update_priority(indices, errors.cpu().numpy())

        if self.learning_steps % self.log_interval == 0:
            self.writer.add_scalar(
                'loss/Q1', q1_loss.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'loss/Q2', q2_loss.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'loss/policy', policy_loss.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'stats/alpha', self.alpha.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'stats/mean_Q1', mean_q1, self.learning_steps)
            self.writer.add_scalar(
                'stats/mean_Q2', mean_q2, self.learning_steps)
            self.writer.add_scalar(
                'stats/entropy', entropies.detach().mean().item(),
                self.learning_steps)

    def calc_critic_loss(self, batch, weights):
        curr_q1, curr_q2 = self.calc_current_q(*batch)
        target_q = self.calc_target_q(*batch)

        # TD errors for updating priority weights
        errors = torch.abs(curr_q1.detach() - target_q)
        # We log means of Q to monitor training.
        mean_q1 = curr_q1.detach().mean().item()
        mean_q2 = curr_q2.detach().mean().item()

        # Critic loss is mean squared TD errors with priority weights.
        q1_loss = torch.mean((curr_q1 - target_q).pow(2) * weights)
        q2_loss = torch.mean((curr_q2 - target_q).pow(2) * weights)
        return q1_loss, q2_loss, errors, mean_q1, mean_q2

    def calc_policy_loss(self, batch, weights):
        states, actions, rewards, next_states, dones = batch

        # We re-sample actions to calculate expectations of Q.
        sampled_action, entropy, _ = self.policy.sample(states)
        # expectations of Q with clipped double Q technique
        q1, q2 = self.critic(states, sampled_action)
        q = torch.min(q1, q2)

        # Policy objective is maximization of (Q + alpha * entropy) with
        # priority weights.
        policy_loss = torch.mean((- q - self.alpha * entropy) * weights)
        return policy_loss, entropy

    def calc_entropy_loss(self, entropy, weights):
        # Intuitively, we increase alpha when the entropy is below the target
        # entropy, and decrease it otherwise.
        entropy_loss = -torch.mean(
            self.log_alpha * (self.target_entropy - entropy).detach()
            * weights)
        return entropy_loss

    def evaluate(self):
        episodes = 10
        returns = np.zeros((episodes,), dtype=np.float32)

        for i in range(episodes):
            state = self.env.reset()
            episode_reward = 0.
            done = False
            while not done:
                action = self.exploit(state)
                next_state, reward, done, _ = self.env.step(action)
                episode_reward += reward
                state = next_state
            returns[i] = episode_reward

        mean_return = np.mean(returns)

        self.writer.add_scalar(
            'reward/test', mean_return, self.steps)
        print('-' * 60)
        print(f'Num steps: {self.steps:<5}  '
              f'reward: {mean_return:<5.1f}')
        print('-' * 60)

    def save_models(self):
        self.policy.save(os.path.join(self.model_dir, 'policy.pth'))
        self.critic.save(os.path.join(self.model_dir, 'critic.pth'))
        self.critic_target.save(
            os.path.join(self.model_dir, 'critic_target.pth'))

    def __del__(self):
        self.writer.close()
        self.env.close()
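Example #6 delegates its optimizer steps to an update_params helper and freezes the target network with grad_false; neither is shown. A plausible sketch of what they are assumed to do (zero grad, backward, optional gradient-norm clipping, step; and disabling gradients on the target network):

import torch


def update_params(optim, network, loss, grad_clip=None, retain_graph=False):
    # Standard optimizer step with optional gradient-norm clipping.
    optim.zero_grad()
    loss.backward(retain_graph=retain_graph)
    if grad_clip is not None and network is not None:
        torch.nn.utils.clip_grad_norm_(network.parameters(), grad_clip)
    optim.step()


def grad_false(network):
    # Disable gradient tracking for a (target) network.
    for param in network.parameters():
        param.requires_grad = False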
Example #7
class BEARQL(object):
    def __init__(self, num_inputs, action_space, args):
        self.gamma = args.gamma
        self.tau = args.tau

        self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device)
        hard_update(self.critic_target, self.critic)

        self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        self.policy_target = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(device)
        hard_update(self.policy_target, self.policy)

        # dual_lambda
        self.dual_lambda = args.init_dual_lambda
        self.dual_step_size = args.dual_step_size
        self.cost_epsilon = args.cost_epsilon

        # coefficient_weight assigned to ensemble variance term
        self.coefficient_weight = args.coefficient_weight

        self.dual_grad_times = args.dual_grad_times

    # used in evaluation
    def select_action(self, state):
        # sample several actions from the policy and greedily pick the one with the highest Q1
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).repeat(10, 1).to(device)
            # state = torch.FloatTensor(state.reshape(1, -1)).to(device)
            action, _, mean = self.policy.sample(state)
            # q1, q2 = self.critic(state, action)
            q1, q2, q3 = self.critic(state, action)
            ind = q1.max(0)[1]
        return action[ind].cpu().data.numpy().flatten()
        # return action.cpu().data.numpy().flatten()

    # MMD functions
    def compute_kernel(self, x, y, sigma):
        batch_size = x.shape[0]
        x_size = x.shape[1]
        y_size = y.shape[1]
        dim = x.shape[2]
        tiled_x = x.view(batch_size, x_size, 1, dim).repeat([1, 1, y_size, 1])
        tiled_y = y.view(batch_size, 1, y_size, dim).repeat([1, x_size, 1, 1])
        return torch.exp(-(tiled_x - tiled_y).pow(2).sum(dim=3) / (2 * sigma))

    def compute_mmd(self, x, y, sigma=20.):
        x_kernel = self.compute_kernel(x, x, sigma)
        y_kernel = self.compute_kernel(y, y, sigma)
        xy_kernel = self.compute_kernel(x, y, sigma)
        square_mmd = x_kernel.mean((1, 2)) + y_kernel.mean((1, 2)) - 2 * xy_kernel.mean((1, 2))
        return square_mmd

    def train(self, prior, memory, batch_size, m=4, n=4):

        # Sample replay buffer / batch
        state_np, action_np, reward_np, next_state_np, mask_np = memory.sample(batch_size=batch_size)
        state_batch = torch.FloatTensor(state_np).to(device)
        next_state_batch = torch.FloatTensor(next_state_np).to(device)
        action_batch = torch.FloatTensor(action_np).to(device)
        reward_batch = torch.FloatTensor(reward_np).to(device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_np).to(device).unsqueeze(1)

        # Critic Training
        with torch.no_grad():

            # Duplicate state 10 times
            next_state_rep = torch.FloatTensor(np.repeat(next_state_np, 10, axis=0)).to(device)

            # Soft Clipped Double Q-learning
            next_state_action, _, _ = self.policy_target.sample(next_state_rep)
            target_Q1, target_Q2, target_Q3 = self.critic_target(next_state_rep, next_state_action)
            target_cat = torch.cat([target_Q1, target_Q2, target_Q3], 1)
            target_Q = 0.75 * target_cat.min(1)[0] + 0.25 * target_cat.max(1)[0]
            target_Q = target_Q.view(batch_size, -1).max(1)[0].view(-1, 1)

            next_q_value = reward_batch + mask_batch * self.gamma * target_Q

        qf1, qf2, qf3 = self.critic(state_batch, action_batch)  # ensemble of k Q-functions
        q_loss = F.mse_loss(qf1, next_q_value) + F.mse_loss(qf2, next_q_value) + F.mse_loss(qf3, next_q_value)

        self.critic_optim.zero_grad()
        q_loss.backward()
        self.critic_optim.step()

        # Actor Training
        with torch.no_grad():
            state_rep_m = torch.FloatTensor(np.repeat(state_np, m, axis=0)).to(device)
            state_rep_n = torch.FloatTensor(np.repeat(state_np, n, axis=0)).to(device)

        for i in range(self.dual_grad_times):
            prior_a_rep, _, _ = prior.sample(state_rep_n)
            prior_a_rep = prior_a_rep.view(batch_size, n, -1)
            pi_rep, _, _ = self.policy.sample(state_rep_m)
            pi_rep = pi_rep.view(batch_size, m, -1)
            mmd_dist = self.compute_mmd(prior_a_rep, pi_rep)

            pi, _, _ = self.policy.sample(state_batch)
            qf1_pi, qf2_pi, qf3_pi = self.critic(state_batch, pi)
            qf_cat = torch.cat([qf1_pi, qf2_pi, qf3_pi], 1)
            # min_qf_pi = torch.min(qf1_pi, qf2_pi)  # used in TD3
            # use conservative estimate of Q as used in BEAR
            qf_mean = qf_cat.mean(1)
            qf_var = qf_cat.var(1)
            min_qf_pi = qf_mean - self.coefficient_weight * qf_var.sqrt()  # used in BEAR

            policy_loss = -(min_qf_pi - self.dual_lambda*mmd_dist).mean()

            self.policy_optim.zero_grad()
            policy_loss.backward()
            self.policy_optim.step()

            # Dual Lambda Training
            self.dual_gradients = mmd_dist.mean().item() - self.cost_epsilon
            self.dual_lambda += self.dual_step_size * self.dual_gradients
            self.dual_lambda = np.clip(self.dual_lambda, np.power(np.e, -5), np.power(np.e, 10))

        # Update Target Networks
        soft_update(self.critic_target, self.critic, self.tau)
        soft_update(self.policy_target, self.policy, self.tau)

        return q_loss.item(), policy_loss.item(), self.dual_lambda, mmd_dist.mean().item()

    # Save model parameters
    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/BEAR_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/BEAR_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))
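A shape-only usage sketch for the MMD helpers in the BEAR example above; the batch size, sample counts, action dimension and the agent instance are all assumptions made for illustration.

import torch

# compute_kernel/compute_mmd operate on (batch, num_samples, action_dim) tensors
# and return one squared-MMD estimate per state in the batch.
prior_actions = torch.randn(256, 4, 6)    # n = 4 actions sampled from the prior per state
policy_actions = torch.randn(256, 4, 6)   # m = 4 actions sampled from the policy per state

# assuming `agent` is a constructed BEARQL instance
mmd = agent.compute_mmd(prior_actions, policy_actions, sigma=20.)
print(mmd.shape)  # torch.Size([256])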
Example #8
class SAC(object):
    def __init__(self, input_space, action_space, args):

        self.use_expert = args.use_expert
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.action_range = [action_space.low, action_space.high]
        self.policy_type = args.policy

        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        # self.device = torch.device("cuda" if args.cuda else "cpu")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # print(torch.cuda.is_available())
        # print(torch.cuda.current_device())
        # print(torch.cuda.device(0))
        # print(torch.cuda.device_count())
        # print(torch.cuda.get_device_name())
        # print(torch.backends.cudnn.version())
        # print(torch.backends.cudnn.is_available())

        self.critic = QNetwork(input_space, action_space.shape[0], args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(input_space, action_space.shape[0], args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target Entropy = −dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
            if self.automatic_entropy_tuning is True:
                self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(input_space, action_space.shape[0], args.hidden_size, action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:
            raise ValueError("Only the Gaussian policy type is supported.")

        # SAC_V
        # self.value = ValueNetwork(input_space).to(device=self.device)
        # self.value_target = ValueNetwork(input_space).to(self.device)
        # self.value_optim = Adam(self.value.parameters(), lr=args.lr)
        # hard_update(self.value_target, self.value)

        # self.policy = GaussianPolicy(input_space, action_space.shape[0], args.hidden_size).to(self.device)
        # self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, eval=False):
        state_map = torch.FloatTensor(state['map']).to(self.device).unsqueeze(0)
        state_lidar = torch.FloatTensor(state['lidar']).to(self.device).unsqueeze(0)
        state_goal = torch.FloatTensor(state['goal']).to(self.device).unsqueeze(0)
        state_plan_len = torch.FloatTensor(state['plan_len']).to(self.device).unsqueeze(0)
        state_robot_info = torch.FloatTensor(state['robot_info']).to(self.device).unsqueeze(0)
        state = {'map': state_map, 'lidar': state_lidar, 'goal': state_goal, 'plan_len': state_plan_len, 'robot_info': state_robot_info}
        if eval is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        action = action.detach().cpu().numpy()[0]
        # return self.rescale_action(action)
        return action

    def rescale_action(self, action):
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
                (self.action_range[1] + self.action_range[0]) / 2.0

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        if not self.use_expert:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
        else:
            state_batch, action_batch, reward_batch, next_state_batch, mask_batch, s_e_batch, a_e_batch = memory.sample(batch_size=batch_size, use_expert=True)

        # The sampled states are lists of dictionaries like
        # [{"map": ..., "lidar": ..., "goal": ...}, ...], so convert them into
        # dictionaries of batched tensors below:
        _state_batch = {'map':[], 'lidar':[], 'goal':[], 'plan_len':[], 'robot_info':[]}
        _next_state_batch = {'map':[], 'lidar':[], 'goal':[], 'plan_len':[], 'robot_info':[]}
        _state_expert_batch = {'map':[], 'lidar':[], 'goal':[], 'plan_len':[], 'robot_info':[]}
        for s in state_batch:
            _state_batch['map'].append(s['map'])
            _state_batch['lidar'].append(s['lidar'])
            _state_batch['goal'].append(s['goal'])
            _state_batch['plan_len'].append(s['plan_len'])
            _state_batch['robot_info'].append(s['robot_info'])

        for s in next_state_batch:
            _next_state_batch['map'].append(s['map'])
            _next_state_batch['lidar'].append(s['lidar'])
            _next_state_batch['goal'].append(s['goal'])
            _next_state_batch['plan_len'].append(s['plan_len'])
            _next_state_batch['robot_info'].append(s['robot_info'])

        if self.use_expert:
            for s in s_e_batch:
                _state_expert_batch['map'].append(s['map'])
                _state_expert_batch['lidar'].append(s['lidar'])
                _state_expert_batch['goal'].append(s['goal'])
                _state_expert_batch['plan_len'].append(s['plan_len'])
                _state_expert_batch['robot_info'].append(s['robot_info'])

        _state_batch['map'] = torch.FloatTensor(_state_batch['map']).to(self.device)
        _state_batch['lidar'] = torch.FloatTensor(_state_batch['lidar']).to(self.device)
        _state_batch['goal'] = torch.FloatTensor(_state_batch['goal']).to(self.device)
        _state_batch['plan_len'] = torch.FloatTensor(_state_batch['plan_len']).to(self.device)
        _state_batch['robot_info'] = torch.FloatTensor(_state_batch['robot_info']).to(self.device)
        _next_state_batch['map'] = torch.FloatTensor(_next_state_batch['map']).to(self.device)
        _next_state_batch['lidar'] = torch.FloatTensor(_next_state_batch['lidar']).to(self.device)
        _next_state_batch['goal'] = torch.FloatTensor(_next_state_batch['goal']).to(self.device)
        _next_state_batch['plan_len'] = torch.FloatTensor(_next_state_batch['plan_len']).to(self.device)
        _next_state_batch['robot_info'] = torch.FloatTensor(_next_state_batch['robot_info']).to(self.device)
        if self.use_expert:
            _state_expert_batch['map'] = torch.FloatTensor(_state_expert_batch['map']).to(self.device)
            _state_expert_batch['lidar'] = torch.FloatTensor(_state_expert_batch['lidar']).to(self.device)
            _state_expert_batch['goal'] = torch.FloatTensor(_state_expert_batch['goal']).to(self.device)
            _state_expert_batch['plan_len'] = torch.FloatTensor(_state_expert_batch['plan_len']).to(self.device)
            _state_expert_batch['robot_info'] = torch.FloatTensor(_state_expert_batch['robot_info']).to(self.device)
            _action_expert_batch = torch.FloatTensor(a_e_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # SAC_V
        # with torch.no_grad():
        #     vf_next_target = self.value_target(_new_next_state_batch)
        #     next_q_value = reward_batch + mask_batch * self.gamma * (vf_next_target)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(_next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(_next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
        qf1, qf2 = self.critic(_state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # Update Policy
        if not self.use_expert:
            pi, log_pi, _ = self.policy.sample(_state_batch)

            qf1_pi, qf2_pi = self.critic(_state_batch, pi)
            min_qf_pi = torch.min(qf1_pi, qf2_pi)
        else:
            pi, log_pi, _ = self.policy.sample(_state_expert_batch)

            qf1_pi, qf2_pi = self.critic(_state_expert_batch, pi)
            min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs


        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()

        # # SAC_V
        # # Regularization Loss
        # reg_loss = 0.001 * (mean.pow(2).mean() + log_std.pow(2).mean())
        # policy_loss += reg_loss

        # self.policy_optim.zero_grad()
        # policy_loss.backward()
        # self.policy_optim.step()

        # # Update Value
        # if not self.use_expert:
        #     vf = self.value(_new_state_batch)
        # else:
        #     vf = self.value(_new_s_e_batch)
        
        # with torch.no_grad():
        #     vf_target = min_qf_pi - (self.alpha * log_pi)

        # vf_loss = F.mse_loss(vf, vf_target) # JV = 𝔼(st)~D[0.5(V(st) - (𝔼at~π[Q(st,at) - α * logπ(at|st)]))^2]

        # self.value_optim.zero_grad()
        # vf_loss.backward()
        # self.value_optim.step()

        # if updates % self.target_update_interval == 0:
        #     soft_update(self.value_target, self.value, self.tau)

        # return vf_loss.item(), qf1_loss.item(), qf2_loss.item(), policy_loss.item()

    # Save model parameters
    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))

        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))
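
The SAC examples in this collection call soft_update / hard_update target-network helpers that are not defined in the snippets. A minimal sketch consistent with how soft_update is invoked above (target network first, source network second, mixing factor tau); this is an assumption, not the original utility code:

# Hypothetical helpers matching the call sites above (assumed, not from the original example)
def soft_update(target, source, tau):
    # Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


def hard_update(target, source):
    # Copy the source parameters into the target network verbatim
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)
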
Example #9
class SACAgent():
    def __init__(self, action_size, state_size, config):
        self.action_size = action_size
        self.state_size = state_size
        self.min_action = config["min_action"]
        self.max_action = config["max_action"]
        self.seed = config["seed"]
        self.tau = config["tau"]
        self.gamma = config["gamma"]
        self.batch_size = config["batch_size"]
        if not torch.cuda.is_available():
            config["device"] = "cpu"
        self.device = config["device"]
        self.eval = config["eval"]
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.vid_path = config["vid_path"]
        print("actions size ", action_size)
        print("actions min ", self.min_action)
        print("actions max ", self.max_action)
        self.critic = QNetwork(state_size, action_size, config["fc1_units"], config["fc2_units"]).to(self.device)
        self.q_optim = torch.optim.Adam(self.critic.parameters(), config["lr_critic"])
        self.target_critic = QNetwork(state_size, action_size, config["fc1_units"], config["fc2_units"]).to(self.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp()
        self.alpha_optim = Adam([self.log_alpha], lr=config["lr_alpha"])
        #self.policy = SACActor(state_size, action_size).to(self.device)
        self.policy = GaussianPolicy(state_size, action_size, 256).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=config["lr_policy"])
        self.max_timesteps = config["max_episodes_steps"]
        self.episodes = config["episodes"]
        self.memory = ReplayBuffer((state_size, ), (action_size, ), config["buffer_size"], self.device)
        pathname = config["seed"]
        tensorboard_name = str(config["res_path"]) + '/runs/' + str(pathname)
        self.writer = SummaryWriter(tensorboard_name)
        self.steps = 0
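        # Target-entropy heuristic: -dim(A), i.e. minus the action dimensionality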
        self.target_entropy = -torch.prod(torch.Tensor((action_size,)).to(self.device)).item()

    def act(self, state, evaluate=False):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            if evaluate is False:
                action, _, _ = self.policy.sample(state)
            else:
                _, _, action = self.policy.sample(state)
            # action = np.clip(action, self.min_action, self.max_action)
            action = action.cpu().numpy()[0]
        #print(action)
        return action
    
    def train_agent(self):
        env = gym.make("LunarLanderContinuous-v2")
        average_reward = 0
        scores_window = deque(maxlen=100)
        s = 0
        t0 = time.time()
        for i_epiosde in range(self.episodes):
            episode_reward = 0
            state = env.reset()
            for t in range(self.max_timesteps):
                s += 1
                action = self.act(state)
                next_state, reward, done, _ = env.step(action)
                episode_reward += reward
                if i_epiosde > 3:
                    self.learn()
                self.memory.add(state, reward, action, next_state, done)
                state = next_state
                if done:
                    scores_window.append(episode_reward)
                    break
            if i_epiosde % self.eval == 0:
                self.eval_policy()
            ave_reward = np.mean(scores_window)
            print("Episode {} Steps {} Reward {} Reward average {} Time {}".format(i_epiosde, t, episode_reward, np.mean(scores_window), time_format(time.time() - t0)))
            self.writer.add_scalar('Aver_reward', ave_reward, self.steps)
            
    
    def learn(self):
        self.steps += 1
        states, rewards, actions, next_states, dones = self.memory.sample(self.batch_size)
        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_states)
            target_q1, target_q2 = self.target_critic(next_states, next_state_action)
            target_min = torch.min(target_q1, target_q2)
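            # Entropy-regularized target: y = r + gamma * (1 - done) * [min_i Q'_i(s', a') - alpha * log pi(a'|s')]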
            q_target = target_min - (self.alpha * next_state_log_pi)
            next_q_value = rewards + (1 - dones) * self.gamma * q_target
        
        qf1, qf2 = self.critic(states, actions)
        # --------------------------update-q--------------------------------------------------------
        loss = F.mse_loss(qf1, next_q_value) + F.mse_loss(qf2, next_q_value) 
        self.q_optim.zero_grad() 
        loss.backward()
        self.q_optim.step()
        self.writer.add_scalar('loss/q', loss, self.steps)


        # --------------------------update-policy--------------------------------------------------------
        pi, log_pi, _ = self.policy.sample(states)
        q_pi1, q_pi2 = self.critic(states, pi)
        min_q_values = torch.min(q_pi1, q_pi2)
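        # Policy objective: J_pi = E_s~D, a~pi[alpha * log pi(a|s) - min_i Q_i(s, a)]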
        policy_loss = ((self.alpha * log_pi) - min_q_values).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()
        self.writer.add_scalar('loss/policy', policy_loss, self.steps)
        
        # --------------------------update-alpha--------------------------------------------------------
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.writer.add_scalar('loss/alpha', alpha_loss, self.steps)

        self.soft_update(self.critic, self.target_critic)
        self.alpha = self.log_alpha.exp()


    
    def soft_update(self, online, target):
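        # Polyak averaging: theta_target <- tau * theta_online + (1 - tau) * theta_target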
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def eval_policy(self, eval_episodes=4):
        env = gym.make("LunarLanderContinuous-v2")
        # env  = wrappers.Monitor(env, str(self.vid_path) + "/{}".format(self.steps), video_callable=lambda episode_id: True,force=True)
        average_reward = 0
        scores_window = deque(maxlen=100)
        for i_epiosde in range(eval_episodes):
            print("Eval Episode {} of {} ".format(i_epiosde, eval_episodes))
            episode_reward = 0
            state = env.reset()
            while True: 
                action = self.act(state, evaluate=True)
                state, reward, done, _ = env.step(action)
                episode_reward += reward
                if done:
                    break
            scores_window.append(episode_reward)
        average_reward = np.mean(scores_window)
        self.writer.add_scalar('Eval_reward', average_reward, self.steps)
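
A usage sketch for this agent; the config keys below are exactly the ones read in __init__, train_agent() and learn(), but every concrete value is an illustrative assumption rather than part of the original example:

# Illustrative configuration only; all values are assumptions
config = {
    "min_action": -1.0, "max_action": 1.0, "seed": 0,
    "tau": 0.005, "gamma": 0.99, "batch_size": 256,
    "device": "cuda", "eval": 10, "vid_path": "videos",
    "fc1_units": 256, "fc2_units": 256,
    "lr_critic": 3e-4, "lr_alpha": 3e-4, "lr_policy": 3e-4,
    "max_episodes_steps": 1000, "episodes": 500,
    "buffer_size": 1_000_000, "res_path": "results",
}

# LunarLanderContinuous-v2 has an 8-dimensional observation and a 2-dimensional action
agent = SACAgent(action_size=2, state_size=8, config=config)
agent.train_agent()
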
Example #10
class Off_policy(Algo):
    def __init__(self):
        super(Off_policy, self).__init__()
        self.memory = Replay_buffer(capacity=p.exploitory_policy_memory_size)
        self.exploratory_policy = GaussianPolicy(
            self.state_space, self.action_space).to(self.device)
        self.exploratory_Q = QNet(self.state_space,
                                  self.action_space).to(self.device)
        self.exploratory_Q_target = QNet(self.state_space,
                                         self.action_space).to(self.device)
        self.exploratory_policy_optim = Adam(
            self.exploratory_policy.parameters(), lr=p.lr)
        self.exploratory_Q_optim = Adam(self.exploratory_Q.parameters(),
                                        lr=p.lr)

        self.target_update(self.exploratory_policy, self.exploitory_policy,
                           1.0)

        self.kl_normalizer = Normalizer(1)
        self.ex_rewards_normalizer = Normalizer(1)

    def start(self):
        total_numsteps = 0

        for episode in itertools.count(1):
            episode_rewards = 0.0
            episode_steps = 0
            done = False
            state = self.env.reset()

            while not done:
                episode_steps += 1
                if p.random_steps > total_numsteps:
                    action = self.env.action_space.sample()
                else:
                    norm_state = self.obs_normalizer.normalize(state)
                    action = self.select_action(norm_state,
                                                self.exploratory_policy)

                if len(self.memory) > p.exploitory_batch_size and len(
                        self.memory) > p.exploratory_batch_size:
                    for i in range(p.exploitory_policy_updates_per_steps):
                        qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha, ex_reward_model_loss = self.update_exploitory_policy(
                            self.memory)
                        if episode % p.exploitory_target_update_interval == 0:
                            self.target_update(self.exploitory_Q_target,
                                               self.exploitory_Q, p.tau)

                    for i in range(p.exploratory_policy_updates_per_steps):
                        ex_qf1_loss, ex_qf2_loss, ex_policy_loss, divergence_loss = self.update_exploratory_policy(
                            self.memory)
                        if episode % p.exploratory_target_update_interval == 0:
                            self.target_update(self.exploratory_Q_target,
                                               self.exploratory_Q, p.tau)

                next_state, reward, done, _ = self.env.step(action)
                total_numsteps += 1
                episode_rewards += reward

                # Ignore the done signal if it comes from hitting the time horizon.
                mask = 1.0 if episode_steps == self.env._max_episode_steps else float(
                    not done)

                self.memory.push((state, action, reward, next_state, mask))
                self.obs_normalizer.update(state)
                state = next_state

            if episode % p.test_freq == 0:
                average_rewards, average_episode_steps = self.test_current_policy(
                )
                try:

                    data = {
                        'average_rewards': average_rewards,
                        'total_numsteps': total_numsteps,
                        'average_episode_steps': average_episode_steps,
                        'qf1_loss': qf1_loss,
                        'qf2_loss': qf2_loss,
                        'exploitory_policy_loss': policy_loss,
                        'alpha_loss': alpha_loss,
                        'alpha_value': alpha,
                        'ex_qf1_loss': ex_qf1_loss,
                        'ex_qf2_loss': ex_qf2_loss,
                        'ex_policy_loss': ex_policy_loss,
                        'ex_reward_model_loss': ex_reward_model_loss,
                        'divergence_loss': divergence_loss
                    }

                    self.log(data)
                except UnboundLocalError:
                    pass

            if total_numsteps > p.max_numsteps:
                self.env.close()
                self.writer.close()
                break

    def update_exploratory_policy(self, memory):
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            p.exploitory_batch_size)
        state_batch, next_state_batch = self.obs_normalizer.normalize(
            state_batch), self.obs_normalizer.normalize(next_state_batch)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            ex_rewards = self.ex_reward_model.get_reward(
                state_batch, next_state_batch)
            ex_rewards = ex_rewards.unsqueeze(1).cpu().numpy()
            ex_reward_batch = self.ex_rewards_normalizer.normalize(ex_rewards)
            self.ex_rewards_normalizer.update(ex_rewards)
            ex_reward_batch = torch.FloatTensor(ex_reward_batch).to(
                self.device)

            ex_next_state_action, ex_next_state_log_pi, _ = self.exploratory_policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.exploratory_Q_target(
                next_state_batch, ex_next_state_action)
            '''
			ex_mean_actions, ex_log_std = self.exploratory_policy(next_state_batch)
			mean_actions, log_std = self.exploitory_policy(next_state_batch)
			ex_normal = Normal(ex_mean_actions, ex_log_std.exp())
			normal = Normal(mean_actions, log_std.exp())
			kl_div = torch.distributions.kl_divergence(ex_normal, normal).mean(1).unsqueeze(1)
			'''

            ex_next_state_log_prob = torch.clamp(
                self.exploratory_policy.get_logprob(next_state_batch,
                                                    ex_next_state_action),
                min=p.log_std_min,
                max=p.log_std_max)
            next_state_log_prob = torch.clamp(
                self.exploitory_policy.get_logprob(next_state_batch,
                                                   ex_next_state_action),
                min=p.log_std_min,
                max=p.log_std_max)

            kl_div = (ex_next_state_log_prob -
                      next_state_log_prob).mean(1).unsqueeze(1)
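            # kl_div is a sample-based estimate of the KL divergence between the exploratory
            # and exploitory policies (cf. the commented-out closed-form version above)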

            min_qf_next_target = p.ex_alpha * (
                torch.min(qf1_next_target, qf2_next_target) -
                (p.alpha * ex_next_state_log_pi)) - kl_div
            next_q_value = ex_reward_batch + mask_batch * p.gamma * (
                min_qf_next_target)

        qf1, qf2 = self.exploratory_Q(state_batch, action_batch)

        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.exploratory_Q_optim.zero_grad()
        qf_loss.backward()
        self.exploratory_Q_optim.step()

        ex_pi, ex_log_pi, _ = self.exploratory_policy.sample(state_batch)

        qf1_pi, qf2_pi = self.exploratory_Q(state_batch, ex_pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        '''
		ex_mean_actions, ex_log_std = self.exploratory_policy(state_batch)
		mean_actions, log_std = self.exploitory_policy(state_batch)
		ex_normal = Normal(ex_mean_actions, ex_log_std.exp())
		normal = Normal(mean_actions, log_std.exp())
		kl_div = torch.distributions.kl_divergence(ex_normal, normal).mean(1).unsqueeze(1)
		'''

        ex_state_log_prob = torch.clamp(self.exploratory_policy.get_logprob(
            state_batch, ex_pi),
                                        min=p.log_std_min,
                                        max=p.log_std_max)
        with torch.no_grad():
            state_log_prob = torch.clamp(self.exploitory_policy.get_logprob(
                state_batch, ex_pi),
                                         min=p.log_std_min,
                                         max=p.log_std_max)
        kl_div = (ex_state_log_prob - state_log_prob).mean(1).unsqueeze(1)

        policy_loss = (p.ex_alpha * ((p.alpha * ex_log_pi) - min_qf_pi) +
                       kl_div).mean()
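        # J_pi_ex = E[ex_alpha * (alpha * log pi_ex(a|s) - min_i Q_i(s,a)) + KL(pi_ex || pi_exploit)]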

        self.exploratory_policy_optim.zero_grad()
        policy_loss.backward()
        self.exploratory_policy_optim.step()

        ex_alpha_loss = torch.Tensor([0.0])

        if settings.automatic_ex_entropy_tuning:
            ex_alpha_loss = -(
                self.ex_log_alpha *
                (ex_log_pi + self.ex_target_entropy).detach()).mean()
            self.ex_alpha_optim.zero_grad()
            ex_alpha_loss.backward()
            self.ex_alpha_optim.step()

            p.ex_alpha = self.ex_log_alpha.exp().item()

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), kl_div.mean().item()
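
Example #10 (and Example #11 below) also depends on a Normalizer class (Normalizer(1), Normalizer(self.state_space), with normalize() and update() methods) that is not included in the snippet. A minimal running mean/variance sketch matching that interface, written here as an assumption:

import numpy as np

class Normalizer:
    # Hypothetical running mean/std normalizer matching the normalize()/update() calls above
    def __init__(self, size, eps=1e-8):
        self.size = size
        self.eps = eps
        self.count = 0
        self.mean = np.zeros(size, dtype=np.float64)
        self.m2 = np.zeros(size, dtype=np.float64)  # sum of squared deviations (Welford)

    def update(self, x):
        # Accept a single observation or a batch of observations
        x = np.asarray(x, dtype=np.float64).reshape(-1, self.size)
        for row in x:
            self.count += 1
            delta = row - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (row - self.mean)

    def normalize(self, x):
        x = np.asarray(x, dtype=np.float64)
        if self.count < 2:
            return x.astype(np.float32)
        std = np.sqrt(self.m2 / (self.count - 1)) + self.eps
        return ((x - self.mean) / std).astype(np.float32)
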
Example #11
class Algo:
    def __init__(self):
        #Creating environment
        self.env = gym.make(settings.env_name)
        self.env.seed(settings.seed)
        self.env.action_space.seed(settings.seed)

        self.state_space = self.env.observation_space.shape[0]
        self.action_space = self.env.action_space.shape[0]

        self.obs_normalizer = Normalizer(self.state_space)

        self.device = torch.device(settings.device)
        self.writer = SummaryWriter(
            'runs/' + settings.env_name + "_" + settings.algo +
            '_{}_{}_{}'.format(p.alpha, p.ex_alpha, settings.seed))

        #Initializing common networks and their optimizers
        self.exploitory_policy = GaussianPolicy(
            self.state_space, self.action_space).to(self.device)
        self.exploitory_Q = QNet(self.state_space,
                                 self.action_space).to(self.device)
        self.exploitory_Q_target = QNet(self.state_space,
                                        self.action_space).to(self.device)
        self.exploitory_policy_optim = Adam(
            self.exploitory_policy.parameters(), lr=p.lr)
        self.exploitory_Q_optim = Adam(self.exploitory_Q.parameters(), lr=p.lr)

        self.target_update(self.exploitory_Q_target, self.exploitory_Q, 1.0)

        p.alpha = torch.Tensor([p.alpha]).to(self.device)
        if settings.automatic_entropy_tuning:
            self.target_entropy = -torch.prod(
                torch.Tensor(self.env.action_space.shape).to(
                    self.device)).item()
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=p.lr)

        if settings.automatic_ex_entropy_tuning:
            self.ex_target_entropy = -torch.prod(
                torch.Tensor(self.env.action_space.shape).to(
                    self.device)).item()
            self.ex_log_alpha = torch.zeros(1,
                                            requires_grad=True,
                                            device=self.device)
            self.ex_alpha_optim = Adam([self.ex_log_alpha], lr=p.lr)

        if settings.reward_model == 'novelty':
            self.ex_reward_model = Novelty(self.state_space, self.device)

    def target_update(self, target, source, tau=p.tau):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def update_exploitory_policy(self, memory):
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            p.exploitory_batch_size)
        state_batch, next_state_batch = self.obs_normalizer.normalize(
            state_batch), self.obs_normalizer.normalize(next_state_batch)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.exploitory_policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.exploitory_Q_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target, qf2_next_target) - p.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * p.gamma * (
                min_qf_next_target)

        qf1, qf2 = self.exploitory_Q(state_batch, action_batch)

        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.exploitory_Q_optim.zero_grad()
        qf_loss.backward()
        self.exploitory_Q_optim.step()

        pi, log_pi, _ = self.exploitory_policy.sample(state_batch)
        qf1_pi, qf2_pi = self.exploitory_Q(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = ((p.alpha * log_pi) - min_qf_pi).mean()

        self.exploitory_policy_optim.zero_grad()
        policy_loss.backward()
        self.exploitory_policy_optim.step()

        alpha_loss = torch.Tensor([0.0])

        if settings.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            p.alpha = self.log_alpha.exp().item()

        ex_reward_model_loss = self.ex_reward_model.update(memory)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), p.alpha, ex_reward_model_loss

    def test_current_policy(self):
        avg_reward = 0
        avg_steps = 0
        avg_ex_rewards = 0

        for episode in range(p.testing_episodes):
            episode_steps = 0
            state = self.env.reset()
            episode_rewards = 0
            episode_ex_rewards = 0
            done = False

            while not done:
                episode_steps += 1
                norm_state = self.obs_normalizer.normalize(state)
                action = self.select_action(norm_state,
                                            self.exploitory_policy,
                                            evaluate=True)
                next_state, reward, done, _ = self.env.step(action)
                episode_rewards += reward

                state = next_state

            avg_reward += episode_rewards
            avg_ex_rewards += episode_ex_rewards
            avg_steps += episode_steps

        avg_reward = avg_reward / p.testing_episodes
        avg_ex_rewards = avg_ex_rewards / p.testing_episodes
        avg_steps = avg_steps / p.testing_episodes

        return avg_reward, avg_steps

    def select_action(self, state, policy, evaluate=False):
        with torch.no_grad():
            try:
                state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
                if evaluate is False:
                    action, log_prob, _ = policy.sample(state)
                else:
                    _, log_prob, action = policy.sample(state)

                return action.cpu().numpy()[0]

            except Exception:  # fall back when state is already a tensor on the right device
                state = state.unsqueeze(0)
                if evaluate is False:
                    action, log_prob, _ = policy.sample(state)
                else:
                    _, log_prob, action = policy.sample(state)

                return action

    def log(self, data):
        for key in data.keys():
            if key != "total_numsteps":
                self.writer.add_scalar(
                    key.split('_')[-1] + "/" + key, data[key],
                    data['total_numsteps'])
        print("Total number of Steps: {} \t Average reward per episode: {}".
              format(data['total_numsteps'], round(data['average_rewards'],
                                                   1)))

    def start(self):
        raise NotImplementedError
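
All of the examples above call GaussianPolicy(...).sample(state) and expect an (action, log_prob, deterministic_action) triple; the policy network itself is not part of the snippets (nor is the get_logprob() method used by Examples #10 and #11). A minimal tanh-squashed Gaussian sketch of that sample() contract, given as an assumption rather than the original network:

import torch
import torch.nn as nn
from torch.distributions import Normal

LOG_SIG_MIN, LOG_SIG_MAX, EPSILON = -20, 2, 1e-6

class GaussianPolicy(nn.Module):
    # Hypothetical reconstruction of the sample() contract used by the examples above
    def __init__(self, num_inputs, num_actions, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(num_inputs, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

    def forward(self, state):
        h = self.net(state)
        mean = self.mean_linear(h)
        log_std = torch.clamp(self.log_std_linear(h), LOG_SIG_MIN, LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        normal = Normal(mean, log_std.exp())
        x_t = normal.rsample()            # reparameterized sample
        action = torch.tanh(x_t)          # squash into (-1, 1)
        # Log-probability with the tanh change-of-variables correction
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + EPSILON)
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob, torch.tanh(mean)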