Exemple #1
0
class SAC(object):
    """Soft Actor-Critic agent with an optional separate soft value network.

    When ``args.policy == "Gaussian"`` the agent builds a stochastic policy, a
    value network and a target copy of that value network.  Any other policy
    type builds a deterministic policy and a target critic instead.
    """

    def __init__(self, num_inputs, action_space, args):
        """Create the networks and their optimizers.

        Args:
            num_inputs: Observation size forwarded to every network.
            action_space: Object exposing ``shape``; ``shape[0]`` is used as
                the action dimension.
            args: Namespace providing ``gamma``, ``tau``, ``policy``,
                ``target_update_interval``, ``automatic_entropy_tuning``,
                ``hidden_size``, ``lr`` and (Gaussian policy only) ``alpha``.
        """
        self.num_inputs = num_inputs
        self.action_space = action_space.shape[0]
        self.gamma = args.gamma
        self.tau = args.tau

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.critic = QNetwork(self.num_inputs, self.action_space,
                               args.hidden_size)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        if self.policy_type == "Gaussian":
            self.alpha = args.alpha
            # Truthiness (not ``== True``) so this agrees with the check made
            # in update_parameters.
            if self.automatic_entropy_tuning:
                # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as
                # given in the paper.
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape)).item()
                self.log_alpha = torch.zeros(1, requires_grad=True)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(self.num_inputs, self.action_space,
                                         args.hidden_size)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

            self.value = ValueNetwork(self.num_inputs, args.hidden_size)
            self.value_target = ValueNetwork(self.num_inputs, args.hidden_size)
            self.value_optim = Adam(self.value.parameters(), lr=args.lr)
            hard_update(self.value_target, self.value)
        else:
            # update_parameters reads self.alpha for every policy type; the
            # deterministic policy uses no entropy bonus, so the weight is 0.
            self.alpha = 0
            self.policy = DeterministicPolicy(self.num_inputs,
                                              self.action_space,
                                              args.hidden_size)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

            self.critic_target = QNetwork(self.num_inputs, self.action_space,
                                          args.hidden_size)
            hard_update(self.critic_target, self.critic)

    def select_action(self, state, eval=False):
        """Return the policy's action for a single state as a numpy array.

        Args:
            state: One observation; it is converted to a float tensor and
                given a leading batch dimension.
            eval: When falsy, the first element returned by
                ``policy.sample`` is used.  Otherwise the fourth element is
                used, passed through ``tanh`` for the Gaussian policy.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        if not eval:
            self.policy.train()
            action, _, _, _, _ = self.policy.sample(state)
        else:
            self.policy.eval()
            _, _, _, action, _ = self.policy.sample(state)
            if self.policy_type == "Gaussian":
                action = torch.tanh(action)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, state_batch, action_batch, reward_batch,
                          next_state_batch, mask_batch, updates):
        """Perform one update of critic, value network, policy and alpha.

        Args:
            state_batch, action_batch, reward_batch, next_state_batch,
            mask_batch: Array-likes describing a batch of transitions; rewards
                and masks get a trailing dimension added via ``unsqueeze(1)``.
            updates: Running update counter; the target network is
                soft-updated when it is a multiple of
                ``target_update_interval``.

        Returns:
            ``(value_loss, q1_loss, q2_loss, policy_loss, alpha_loss, alpha)``
            where the losses are Python floats and ``alpha`` is the current
            entropy weight as stored (number or tensor), for logging.
        """
        state_batch = torch.FloatTensor(state_batch)
        next_state_batch = torch.FloatTensor(next_state_batch)
        action_batch = torch.FloatTensor(action_batch)
        reward_batch = torch.FloatTensor(reward_batch).unsqueeze(1)
        mask_batch = torch.FloatTensor(np.float32(mask_batch)).unsqueeze(1)
        is_gaussian = self.policy_type == "Gaussian"

        # Two Q-functions mitigate positive bias in the policy improvement
        # step, which is known to degrade value based methods.
        expected_q1_value, expected_q2_value = self.critic(
            state_batch, action_batch)
        new_action, log_prob, _, mean, log_std = self.policy.sample(
            state_batch)

        alpha_loss = torch.tensor(0.)
        alpha_logs = self.alpha  # For TensorboardX logs
        if is_gaussian:
            if self.automatic_entropy_tuning:
                # Alpha loss: log_prob is detached, so only log_alpha
                # receives a gradient here.
                alpha_loss = -(
                    self.log_alpha *
                    (log_prob + self.target_entropy).detach()).mean()
                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()
                self.alpha = self.log_alpha.exp()
                alpha_logs = self.alpha.clone()  # For TensorboardX logs
            # A separate function approximator for the soft value can
            # stabilize training.
            expected_value = self.value(state_batch)
            bootstrap_value = self.value_target(next_state_batch)
        else:
            # No separate state-value approximator for the deterministic
            # policy: bootstrap from the target critic instead.
            next_state_action, _, _, _, _ = self.policy.sample(
                next_state_batch)
            target_critic_1, target_critic_2 = self.critic_target(
                next_state_batch, next_state_action)
            bootstrap_value = torch.min(target_critic_1, target_critic_2)
        next_q_value = (reward_batch +
                        mask_batch * self.gamma * bootstrap_value.detach())

        # Soft Bellman residual:
        # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - gamma * V(st+1))^2]
        q1_value_loss = F.mse_loss(expected_q1_value, next_q_value)
        q2_value_loss = F.mse_loss(expected_q2_value, next_q_value)
        q1_new, q2_new = self.critic(state_batch, new_action)
        expected_new_q_value = torch.min(q1_new, q2_new)

        if is_gaussian:
            # JV = E st~D[0.5(V(st) - E at~pi[Qmin(st,at) - alpha*log pi])^2]
            next_value = expected_new_q_value - (self.alpha * log_prob)
            value_loss = F.mse_loss(expected_value, next_value.detach())
        else:
            value_loss = torch.tensor(0.)

        # Reparameterized policy objective:
        # J_pi = E[alpha * log pi(f(eps;st)|st) - Q(st, f(eps;st))]
        policy_loss = ((self.alpha * log_prob) - expected_new_q_value).mean()

        # Regularization loss
        mean_loss = 0.001 * mean.pow(2).mean()
        std_loss = 0.001 * log_std.pow(2).mean()
        policy_loss += mean_loss + std_loss

        # Run every backward pass BEFORE any optimizer step.  The policy loss
        # backpropagates through the critic, so stepping the critic first
        # modifies weights autograd still needs and raises a RuntimeError on
        # current PyTorch.
        self.policy_optim.zero_grad()
        policy_loss.backward()

        # zero_grad here also discards the critic gradients produced by the
        # policy loss above, so the critic is trained on the Q losses only.
        self.critic_optim.zero_grad()
        (q1_value_loss + q2_value_loss).backward()

        if is_gaussian:
            self.value_optim.zero_grad()
            value_loss.backward()

        self.critic_optim.step()
        if is_gaussian:
            self.value_optim.step()
        self.policy_optim.step()

        # Periodically move the target weights towards the current weights.
        # Every non-Gaussian policy owns a target critic (see __init__).
        if updates % self.target_update_interval == 0:
            if is_gaussian:
                soft_update(self.value_target, self.value, self.tau)
            else:
                soft_update(self.critic_target, self.critic, self.tau)

        return (value_loss.item(), q1_value_loss.item(), q2_value_loss.item(),
                policy_loss.item(), alpha_loss.item(), alpha_logs)

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None,
                   value_path=None):
        """Write policy, critic and (if present) value weights to disk.

        Paths left as ``None`` default to ``models/sac_<kind>_<env>_<suffix>``.
        The value network only exists for the Gaussian policy, so it is
        skipped when the agent has none.
        """
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        if value_path is None:
            value_path = "models/sac_value_{}_{}".format(env_name, suffix)

        has_value = hasattr(self, "value")
        if has_value:
            print('Saving models to {}, {} and {}'.format(
                actor_path, critic_path, value_path))
            torch.save(self.value.state_dict(), value_path)
        else:
            print('Saving models to {} and {}'.format(actor_path,
                                                      critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path, value_path):
        """Load weights from the given paths; ``None`` paths are skipped.

        ``value_path`` is ignored when the agent has no value network
        (deterministic policy).
        """
        print('Loading models from {}, {} and {}'.format(
            actor_path, critic_path, value_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))
        if value_path is not None and hasattr(self, "value"):
            self.value.load_state_dict(torch.load(value_path))
Exemple #2
0
class SAC(object):
    """Device-aware Soft Actor-Critic agent with an optional value network.

    When ``args.policy == "Gaussian"`` the agent builds a stochastic policy, a
    value network and a target value network.  Any other policy type builds a
    deterministic policy and a target critic instead.  All networks live on
    ``cuda`` when ``args.cuda`` is truthy, otherwise on ``cpu``.
    """

    def __init__(self, num_inputs, action_space, args):
        """Create the networks and their optimizers.

        Args:
            num_inputs: Observation size forwarded to every network.
            action_space: Object exposing ``shape``; ``shape[0]`` is used as
                the action dimension.
            args: Namespace providing ``gamma``, ``tau``, ``policy``,
                ``target_update_interval``, ``automatic_entropy_tuning``,
                ``cuda``, ``hidden_size``, ``lr`` and (Gaussian policy only)
                ``alpha``.
        """
        self.num_inputs = num_inputs
        self.action_space = action_space.shape[0]
        self.gamma = args.gamma
        self.tau = args.tau

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.critic = QNetwork(self.num_inputs, self.action_space,
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        if self.policy_type == "Gaussian":
            self.alpha = args.alpha
            # Truthiness (not ``== True``) so this agrees with the check made
            # in update_parameters.
            if self.automatic_entropy_tuning:
                # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as
                # given in the paper.
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(self.num_inputs, self.action_space,
                                         args.hidden_size).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

            self.value = ValueNetwork(self.num_inputs,
                                      args.hidden_size).to(self.device)
            self.value_target = ValueNetwork(self.num_inputs,
                                             args.hidden_size).to(self.device)
            self.value_optim = Adam(self.value.parameters(), lr=args.lr)
            hard_update(self.value_target, self.value)
        else:
            # update_parameters reads self.alpha for every policy type; the
            # deterministic policy uses no entropy bonus, so the weight is 0.
            self.alpha = 0
            self.policy = DeterministicPolicy(self.num_inputs,
                                              self.action_space,
                                              args.hidden_size).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

            self.critic_target = QNetwork(self.num_inputs, self.action_space,
                                          args.hidden_size).to(self.device)
            hard_update(self.critic_target, self.critic)

    def select_action(self, state, eval=False):
        """Return the policy's action for a single state as a numpy array.

        Args:
            state: One observation; it is converted to a float tensor on the
                agent's device and given a leading batch dimension.
            eval: When falsy the first element returned by ``policy.sample``
                is used, otherwise the third.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if not eval:
            self.policy.train()
            action, _, _ = self.policy.sample(state)
        else:
            self.policy.eval()
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, state_batch, action_batch, reward_batch,
                          next_state_batch, mask_batch, updates):
        """Perform one update of critic, value network, policy and alpha.

        Args:
            state_batch, action_batch, reward_batch, next_state_batch,
            mask_batch: Array-likes describing a batch of transitions; rewards
                and masks get a trailing dimension added via ``unsqueeze(1)``.
            updates: Running update counter; the target network is
                soft-updated when it is a multiple of
                ``target_update_interval``.

        Returns:
            ``(value_loss, q1_loss, q2_loss, policy_loss, alpha_loss, alpha)``
            as Python numbers.
        """
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        is_gaussian = self.policy_type == "Gaussian"

        # Two Q-functions to mitigate positive bias in the policy improvement
        # step.
        qf1, qf2 = self.critic(state_batch, action_batch)
        pi, log_pi, _ = self.policy.sample(state_batch)

        alpha_loss = torch.tensor(0.).to(self.device)
        if is_gaussian:
            if self.automatic_entropy_tuning:
                # log_pi is detached, so only log_alpha receives a gradient.
                alpha_loss = -(self.log_alpha *
                               (log_pi + self.target_entropy).detach()).mean()
                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()
                self.alpha = self.log_alpha.exp()
                # clone() rather than torch.tensor(): alpha is already a
                # tensor at this point.
                alpha_logs = self.alpha.clone()  # For TensorboardX logs
            else:
                alpha_logs = torch.tensor(self.alpha)  # For TensorboardX logs

            # A separate function approximator for the soft value can
            # stabilize training.
            vf = self.value(state_batch)
            with torch.no_grad():
                vf_next_target = self.value_target(next_state_batch)
                next_q_value = reward_batch + mask_batch * self.gamma * (
                    vf_next_target)
        else:
            # Wrapped in a tensor so the ``.item()`` call in the return
            # statement works for this branch as well.
            alpha_logs = torch.tensor(self.alpha)  # For TensorboardX logs
            with torch.no_grad():
                # policy.sample returns three values everywhere else in this
                # class; unpack three here too.
                next_state_action, _, _ = self.policy.sample(next_state_batch)
                # Use a target critic for the deterministic policy instead of
                # a value network.
                qf1_next_target, qf2_next_target = self.critic_target(
                    next_state_batch, next_state_action)
                min_qf_next_target = torch.min(qf1_next_target,
                                               qf2_next_target)
                next_q_value = reward_batch + mask_batch * self.gamma * (
                    min_qf_next_target)

        # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - gamma * V(st+1))^2]
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        if is_gaussian:
            # JV = E st~D[0.5(V(st) - E at~pi[Qmin(st,at) - alpha*log pi])^2]
            vf_target = min_qf_pi - (self.alpha * log_pi)
            value_loss = F.mse_loss(vf, vf_target.detach())
        else:
            value_loss = torch.tensor(0.).to(self.device)

        # J_pi = E[alpha * log pi(f(eps;st)|st) - Q(st, f(eps;st))]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        # Run every backward pass BEFORE any optimizer step.  The policy loss
        # backpropagates through the critic, so stepping the critic first
        # modifies weights autograd still needs and raises a RuntimeError on
        # current PyTorch.
        self.policy_optim.zero_grad()
        policy_loss.backward()

        # zero_grad here also discards the critic gradients produced by the
        # policy loss above, so the critic is trained on the Q losses only.
        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()

        if is_gaussian:
            self.value_optim.zero_grad()
            value_loss.backward()

        self.critic_optim.step()
        if is_gaussian:
            self.value_optim.step()
        self.policy_optim.step()

        # Periodically move the target weights towards the current weights.
        # Every non-Gaussian policy owns a target critic (see __init__).
        if updates % self.target_update_interval == 0:
            if is_gaussian:
                soft_update(self.value_target, self.value, self.tau)
            else:
                soft_update(self.critic_target, self.critic, self.tau)

        return (value_loss.item(), qf1_loss.item(), qf2_loss.item(),
                policy_loss.item(), alpha_loss.item(), alpha_logs.item())

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None,
                   value_path=None):
        """Write policy, critic and (if present) value weights to disk.

        Paths left as ``None`` default to ``models/sac_<kind>_<env>_<suffix>``.
        The value network only exists for the Gaussian policy, so it is
        skipped when the agent has none.
        """
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        if value_path is None:
            value_path = "models/sac_value_{}_{}".format(env_name, suffix)

        has_value = hasattr(self, "value")
        if has_value:
            print('Saving models to {}, {} and {}'.format(
                actor_path, critic_path, value_path))
            torch.save(self.value.state_dict(), value_path)
        else:
            print('Saving models to {} and {}'.format(actor_path,
                                                      critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path, value_path):
        """Load weights from the given paths; ``None`` paths are skipped.

        ``value_path`` is ignored when the agent has no value network
        (deterministic policy).
        """
        print('Loading models from {}, {} and {}'.format(
            actor_path, critic_path, value_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))
        if value_path is not None and hasattr(self, "value"):
            self.value.load_state_dict(torch.load(value_path))
Exemple #3
0
class SAC(object):
    """Soft Actor-Critic agent configured from a ``variant`` dictionary.

    The agent owns a twin critic with a target copy and either a Gaussian
    policy (``policy_type == 'Gaussian'``) or a deterministic policy with the
    entropy weight forced to zero.
    """

    def __init__(self, num_inputs, action_space, variant):
        """Create the networks and their optimizers.

        Args:
            num_inputs: Observation size forwarded to the networks.
            action_space: Object exposing ``shape``; it is also handed to the
                policy constructor.
            variant: Mapping with keys ``gamma``, ``tau``, ``alpha``,
                ``policy_type``, ``target_update_interval``,
                ``automatic_entropy_tuning`` and ``cuda``; ``lr`` (default
                ``1e-3``) and ``hidden_size`` (default ``[128, 128]``) are
                optional.
        """
        self.gamma = variant['gamma']
        self.tau = variant['tau']
        self.alpha = variant['alpha']
        self.policy_type = variant['policy_type']
        self.target_update_interval = variant['target_update_interval']
        self.automatic_entropy_tuning = variant['automatic_entropy_tuning']
        self.lr = variant.get("lr", 1e-3)

        self.device = torch.device("cuda" if variant['cuda'] else "cpu")
        self.hidden_size = variant.get('hidden_size', [128, 128])

        action_dim = action_space.shape[0]

        self.critic = QNetwork(num_inputs, action_dim,
                               self.hidden_size).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=self.lr)

        self.critic_target = QNetwork(num_inputs, action_dim,
                                      self.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == 'Gaussian':
            if self.automatic_entropy_tuning:
                # Target entropy = -dim(A).
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=self.lr)

            self.policy = GaussianPolicy(num_inputs, action_dim,
                                         self.hidden_size,
                                         action_space).to(self.device)
        else:
            # Deterministic policy: no entropy bonus and nothing to tune.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_dim,
                                              self.hidden_size,
                                              action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=self.lr)

    def select_action(self, state, evaluate=False):
        """Return the policy's action for a single state as a numpy array.

        Args:
            state: One observation; it is converted to a float tensor on the
                agent's device and given a leading batch dimension.
            evaluate: When falsy the first element returned by
                ``policy.sample`` is used, otherwise the third.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        # Truthiness instead of ``is False`` so that 0 / numpy bools behave
        # like their bool counterparts.
        if not evaluate:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Sample a batch from ``memory`` and update critic, policy and alpha.

        Args:
            memory: Object whose ``sample(batch_size=...)`` returns
                ``(state, action, reward, next_state, mask)`` batches.
            batch_size: Number of transitions to sample.
            updates: Running update counter; the target critic is
                soft-updated when it is a multiple of
                ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss)`` as floats.
        """
        # Sample a batch from memory.
        (state_batch, action_batch, reward_batch, next_state_batch,
         mask_batch) = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        # NOTE(review): rewards and masks are used as returned by the memory,
        # without ``unsqueeze(1)``.  This presumably relies on memory.sample
        # already returning them with the same shape as the critic outputs;
        # otherwise the target below would broadcast - confirm.
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement
        # step.
        qf1, qf2 = self.critic(state_batch, action_batch)
        # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - gamma * V(st+1))^2]
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        # Sample a batch of actions and the matching log_pi.
        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # J_pi = E[alpha * log pi(f(eps;st)|st) - Q(st, f(eps;st))]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        # Run both backward passes BEFORE any optimizer step.  The policy loss
        # backpropagates through the critic, so stepping the critic first
        # modifies weights autograd still needs and raises a RuntimeError on
        # current PyTorch.
        self.policy_optim.zero_grad()
        policy_loss.backward()

        # zero_grad here also discards the critic gradients produced by the
        # policy loss above, so the critic is trained on the Q losses only.
        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()

        self.critic_optim.step()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            # log_pi is detached, so only log_alpha receives a gradient.
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.tensor(0.0).to(self.device)

        # ``updates`` is the parameter name (the original referenced an
        # undefined ``update``).
        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return (qf1_loss.item(), qf2_loss.item(), policy_loss.item(),
                alpha_loss.item())

    def save_model(self,
                   env_name,
                   suffix=".pkl",
                   actor_path=None,
                   critic_path=None):
        """Write policy and critic weights to disk.

        Paths left as ``None`` default to
        ``models/sac_actor_<env_name>_<suffix>`` and
        ``models/sac_critic_<env_name>_<suffix>``.
        """
        if not os.path.exists('models/'):
            os.makedirs('models/')
        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)

        print("Saving models to {} and {}".format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    def load_model(self, actor_path, critic_path):
        """Load policy and critic weights; ``None`` paths are skipped."""
        print('loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))
class SAC(object):
    """Soft Actor-Critic agent with twin critics and action rescaling.

    The agent owns a twin critic with a target copy and either a Gaussian
    policy (``args.policy == "Gaussian"``) or a deterministic policy with the
    entropy weight forced to zero.  Actions returned by ``select_action`` are
    mapped onto ``[action_space.low, action_space.high]``.
    """

    def __init__(self, num_inputs, action_space, args):
        """Create the networks and their optimizers.

        Args:
            num_inputs: Observation size forwarded to the networks.
            action_space: Object exposing ``shape``, ``low`` and ``high``.
            args: Namespace providing ``gamma``, ``tau``, ``alpha``,
                ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and
                ``lr``.
        """
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.action_range = [action_space.low, action_space.high]

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        action_dim = action_space.shape[0]

        self.critic = QNetwork(num_inputs, action_dim,
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_dim,
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Truthiness (not ``== True``) so this agrees with the check made
            # in update_parameters.
            if self.automatic_entropy_tuning:
                # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as
                # given in the paper.
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_dim,
                                         args.hidden_size).to(self.device)
        else:
            # Deterministic policy: no entropy bonus and nothing to tune.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_dim,
                                              args.hidden_size).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, eval=False):
        """Return the rescaled action for a single state as a numpy array.

        Args:
            state: One observation; it is converted to a float tensor on the
                agent's device and given a leading batch dimension.
            eval: When falsy the first element returned by ``policy.sample``
                is used, otherwise the third.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if not eval:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        action = action.detach().cpu().numpy()[0]
        return self.rescale_action(action)

    def rescale_action(self, action):
        """Affinely map ``action`` so that -1 -> low and +1 -> high."""
        low, high = self.action_range
        return action * (high - low) / 2.0 + (high + low) / 2.0

    def update_parameters(self, memory, batch_size, updates):
        """Sample a batch from ``memory`` and update critic, policy and alpha.

        Args:
            memory: Object whose ``sample(batch_size=...)`` returns
                ``(state, action, reward, next_state, mask)`` batches.
            batch_size: Number of transitions to sample.
            updates: Running update counter; the target critic is
                soft-updated when it is a multiple of
                ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as
            Python numbers.
        """
        # Sample a batch from memory.
        (state_batch, action_batch, reward_batch, next_state_batch,
         mask_batch) = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement
        # step.
        qf1, qf2 = self.critic(state_batch, action_batch)
        # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - gamma * V(st+1))^2]
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # J_pi = E[alpha * log pi(f(eps;st)|st) - Q(st, f(eps;st))]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        # Run both backward passes BEFORE any optimizer step.  The policy loss
        # backpropagates through the critic, so stepping the critic first
        # modifies weights autograd still needs and raises a RuntimeError on
        # current PyTorch.
        self.policy_optim.zero_grad()
        policy_loss.backward()

        # zero_grad here also discards the critic gradients produced by the
        # policy loss above, so the critic is trained on the Q losses only.
        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()

        self.critic_optim.step()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            # log_pi is detached, so only log_alpha receives a gradient.
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return (qf1_loss.item(), qf2_loss.item(), policy_loss.item(),
                alpha_loss.item(), alpha_tlogs.item())

    # Save model parameters
    def save_model(self, suffix="", actor_path=None, critic_path=None):
        """Write policy and critic weights to disk.

        Paths left as ``None`` default to ``models/a_<suffix>`` and
        ``models/c_<suffix>``.
        """
        if not os.path.exists('models/'):
            os.makedirs('models/')

        if actor_path is None:
            actor_path = "models/a_{}".format(suffix)
        if critic_path is None:
            critic_path = "models/c_{}".format(suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        """Load policy and critic weights; ``None`` paths are skipped."""
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path))
Exemple #5
0
class TD3Agent():
    """Agent with one deterministic policy, two Q-functions and a target copy of each.

    The policy objects are used through ``sample(state)`` and the Q-functions
    through ``Q(state, action)``.
    """

    def __init__(self, s_dim, a_dim, action_space, args):
        """Build the networks, their targets and one Adam optimizer per trained network.

        Args:
            s_dim: State dimension handed to the policy and Q-functions.
            a_dim: Action dimension handed to the policy and Q-functions.
            action_space: Forwarded to ``DeterministicPolicy``.
            args: Hyperparameter namespace; reads ``lr_pi``, ``lr_q``,
                ``gamma``, ``tau``, ``noise_std``, ``noise_clip``,
                ``batch_size``, ``policy_update_interval`` and ``device``.
        """
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.action_space = action_space
        self.lr_pi = args.lr_pi
        self.lr_q = args.lr_q
        self.gamma = args.gamma
        self.tau = args.tau
        self.noise_std = args.noise_std
        self.noise_clip = args.noise_clip
        self.batch_size = args.batch_size
        self.policy_update_interval = args.policy_update_interval
        self.device = torch.device(args.device)
        # Last policy loss, reported by learn() on steps that skip the policy update.
        self.policy_loss_log = torch.tensor(0.).to(self.device)

        self.policy = DeterministicPolicy(self.s_dim,
                                          self.a_dim,
                                          self.device,
                                          action_space=self.action_space).to(
                                              self.device)
        self.policy_target = DeterministicPolicy(
            self.s_dim,
            self.a_dim,
            self.device,
            action_space=self.action_space).to(self.device)
        self.Q1 = QFunction(self.s_dim, self.a_dim).to(self.device)
        self.Q1_target = QFunction(self.s_dim, self.a_dim).to(self.device)
        self.Q2 = QFunction(self.s_dim, self.a_dim).to(self.device)
        self.Q2_target = QFunction(self.s_dim, self.a_dim).to(self.device)
        self.hard_update_target()

        self.optimizer_pi = optim.Adam(self.policy.parameters(), lr=self.lr_pi)
        self.optimizer_q1 = optim.Adam(self.Q1.parameters(), lr=self.lr_q)
        self.optimizer_q2 = optim.Adam(self.Q2.parameters(), lr=self.lr_q)

    def hard_update_target(self):
        """Copy the online weights verbatim into every target network."""
        self.policy_target.load_state_dict(self.policy.state_dict())
        self.Q1_target.load_state_dict(self.Q1.state_dict())
        self.Q2_target.load_state_dict(self.Q2.state_dict())

    def soft_update_target(self):
        """Move every target parameter towards its online counterpart: tau * online + (1 - tau) * target."""
        pairs = ((self.policy, self.policy_target),
                 (self.Q1, self.Q1_target),
                 (self.Q2, self.Q2_target))
        for online, target in pairs:
            for param, param_target in zip(online.parameters(),
                                           target.parameters()):
                param_target.data.copy_(param.data * self.tau +
                                        param_target.data * (1 - self.tau))

    def choose_action(self, s):
        """Return ``policy.sample`` for a numpy state as a numpy array."""
        s = torch.from_numpy(s).to(self.device).float()
        return self.policy.sample(s).cpu().detach().numpy()

    def learn(self, memory, total_step):
        """Run one update of both Q-functions and, every ``policy_update_interval`` steps, the policy.

        Args:
            memory: Replay buffer exposing ``sample_batch(batch_size)`` that
                returns numpy arrays ``(s, a, r, s_, done)``; ``r`` and
                ``done`` get a trailing unit dimension added here.
            total_step: Global step counter that gates the delayed policy update.

        Returns:
            Tuple ``(q1_loss, q2_loss, policy_loss)`` of Python floats; the
            policy loss is the most recently computed one.
        """
        s, a, r, s_, done = memory.sample_batch(self.batch_size)
        # Cast to float32 like choose_action does, so float64 or bool replay
        # arrays do not trip dtype errors in the networks or in ``1 - done``.
        s = torch.from_numpy(s).float().to(self.device)
        a = torch.from_numpy(a).float().to(self.device)
        r = torch.from_numpy(r).float().to(self.device).unsqueeze(dim=1)
        s_ = torch.from_numpy(s_).float().to(self.device)
        done = torch.from_numpy(done).float().to(self.device).unsqueeze(dim=1)

        # The bootstrap target is a constant for both critic losses, so build
        # it without recording an autograd graph.
        # NOTE(review): the noisy target action is not clamped to the action
        # bounds here — confirm whether policy_target.sample already bounds it.
        with torch.no_grad():
            noise = (torch.randn_like(a) * self.noise_std).clamp(
                -self.noise_clip, self.noise_clip)
            a_target_next = self.policy_target.sample(s_) + noise
            q1_next = self.Q1_target(s_, a_target_next)
            q2_next = self.Q2_target(s_, a_target_next)
            q_next_min = torch.min(q1_next, q2_next)
            q_loss_target = r + (1 - done) * self.gamma * q_next_min

        # update q1
        q1_loss = F.mse_loss(self.Q1(s, a), q_loss_target)
        self.optimizer_q1.zero_grad()
        q1_loss.backward()
        self.optimizer_q1.step()

        # update q2
        q2_loss = F.mse_loss(self.Q2(s, a), q_loss_target)
        self.optimizer_q2.zero_grad()
        q2_loss.backward()
        self.optimizer_q2.step()

        # delayed policy update (targets are only moved on these steps too)
        if total_step % self.policy_update_interval == 0:
            policy_loss = -self.Q1(s, self.policy.sample(s)).mean()
            self.optimizer_pi.zero_grad()
            policy_loss.backward()
            self.optimizer_pi.step()
            self.soft_update_target()

            # Detach so the logged value does not keep the graph alive.
            self.policy_loss_log = policy_loss.detach()

        return q1_loss.item(), q2_loss.item(), self.policy_loss_log.item()

    def save_model(self,
                   env_name,
                   remarks='',
                   pi_path=None,
                   q1_path=None,
                   q2_path=None):
        """Save the policy, Q1 and Q2 state dicts.

        Paths left as None default to
        ``pretrained_models/{policy,q1,q2}_<env_name>_<remarks>``.
        """
        # exist_ok avoids the exists()/mkdir() check-then-create race.
        os.makedirs('pretrained_models/', exist_ok=True)

        if pi_path is None:
            pi_path = 'pretrained_models/policy_{}_{}'.format(
                env_name, remarks)
        if q1_path is None:
            q1_path = 'pretrained_models/q1_{}_{}'.format(env_name, remarks)
        if q2_path is None:
            q2_path = 'pretrained_models/q2_{}_{}'.format(env_name, remarks)
        print('Saving model to {} , {} and {}'.format(pi_path, q1_path,
                                                      q2_path))
        torch.save(self.policy.state_dict(), pi_path)
        torch.save(self.Q1.state_dict(), q1_path)
        torch.save(self.Q2.state_dict(), q2_path)

    def load_model(self, pi_path, q1_path, q2_path):
        """Load the policy, Q1 and Q2 state dicts and re-sync the target networks."""
        print('Loading models from {} , {} and {}'.format(
            pi_path, q1_path, q2_path))
        # map_location lets a checkpoint saved on another device load here.
        self.policy.load_state_dict(
            torch.load(pi_path, map_location=self.device))
        self.Q1.load_state_dict(torch.load(q1_path, map_location=self.device))
        self.Q2.load_state_dict(torch.load(q2_path, map_location=self.device))
        # Targets are not part of the checkpoint; without this they would keep
        # their previous weights and give stale bootstrap targets.
        self.hard_update_target()
Exemple #6
0
class DDPGAgent():
    """Actor-critic agent with a deterministic actor, one critic and a target copy of each."""

    def __init__(self, args, env_params):
        """Build actor/critic networks, their targets and optimizers.

        Args:
            args: Hyperparameter namespace; reads ``lr_a``, ``lr_c``,
                ``gamma``, ``tau``, ``noise_eps``, ``batch_size`` and ``device``.
            env_params: Mapping with keys ``'o_dim'``, ``'a_dim'`` and
                ``'action_boundary'``.
        """
        self.o_dim = env_params['o_dim']
        self.a_dim = env_params['a_dim']
        self.action_boundary = env_params['action_boundary']

        self.lr_a = args.lr_a
        self.lr_c = args.lr_c
        self.gamma = args.gamma
        self.tau = args.tau
        self.noise_eps = args.noise_eps
        self.batch_size = args.batch_size

        self.device = torch.device(args.device)

        self.actor = DeterministicPolicy(self.o_dim,
                                         self.a_dim).to(self.device)
        self.actor_tar = DeterministicPolicy(self.o_dim,
                                             self.a_dim).to(self.device)
        self.critic = QFunction(self.o_dim, self.a_dim).to(self.device)
        self.critic_tar = QFunction(self.o_dim, self.a_dim).to(self.device)

        self.optimizer_a = optim.Adam(self.actor.parameters(), lr=self.lr_a)
        self.optimizer_c = optim.Adam(self.critic.parameters(), lr=self.lr_c)

        self.hard_update()

    def hard_update(self):
        """Copy the online actor/critic weights verbatim into their targets."""
        self.actor_tar.load_state_dict(self.actor.state_dict())
        self.critic_tar.load_state_dict(self.critic.state_dict())

    def soft_update(self):
        """Blend each target parameter with its online one: tau * online + (1 - tau) * target."""
        pairs = ((self.actor, self.actor_tar), (self.critic, self.critic_tar))
        for online, target in pairs:
            for params, params_tar in zip(online.parameters(),
                                          target.parameters()):
                params_tar.data.copy_(self.tau * params.data +
                                      (1 - self.tau) * params_tar.data)

    def choose_action(self, obs, is_evaluete=False):
        """Return the actor's action for ``obs`` as a numpy array clipped to +/- ``action_boundary``.

        Args:
            obs: Numpy observation.
            is_evaluete: When False, zero-mean Gaussian noise with standard
                deviation ``noise_eps`` is added before clipping.
        """
        obs = torch.from_numpy(obs).float().to(self.device)
        with torch.no_grad():
            action = self.actor(obs)
        if not is_evaluete:
            # Draw independent noise for every action component; a single
            # scalar sample would shift all components by the same amount.
            action = action + torch.randn_like(action) * self.noise_eps
        action = torch.clamp(action, -self.action_boundary,
                             self.action_boundary).cpu().detach().numpy()
        return action

    def rollout(self, env, memory, is_evaluate=False):
        """Play one episode in ``env``, store every transition in ``memory`` and return the summed reward."""
        total_reward = 0.
        obs = env.reset()
        done = False
        while not done:
            a = self.choose_action(obs, is_evaluate)
            obs_, r, done, info = env.step(a)

            memory.store(obs, a, r, obs_, done)

            total_reward += r
            obs = obs_
        return total_reward

    def update(self, memory):
        """Run one critic step, one actor step and a soft target update.

        Args:
            memory: Replay buffer exposing ``sample_batch(batch_size)`` that
                returns numpy arrays ``(obs, a, r, obs_, done)``.

        Returns:
            Tuple ``(critic_loss, actor_loss)`` of Python floats.
        """
        obs, a, r, obs_, done = memory.sample_batch(self.batch_size)

        obs = torch.from_numpy(obs).float().to(self.device)
        a = torch.from_numpy(a).float().to(self.device)
        r = torch.from_numpy(r).float().to(self.device)
        obs_ = torch.from_numpy(obs_).float().to(self.device)
        done = torch.from_numpy(done).float().to(self.device)

        with torch.no_grad():
            next_action_tar = self.actor_tar(obs_)
            next_q_tar = self.critic_tar(obs_, next_action_tar)
            # Give r/done the critic's output shape. A (B,) reward added to a
            # (B, 1) Q-value would otherwise broadcast into a (B, B) target
            # and silently mix samples in the loss.
            r = r.reshape(next_q_tar.shape)
            done = done.reshape(next_q_tar.shape)
            critic_target = r + (1 - done) * self.gamma * next_q_tar
        critic_eval = self.critic(obs, a)
        loss_critic = F.mse_loss(critic_eval, critic_target)
        self.optimizer_c.zero_grad()
        loss_critic.backward()
        self.optimizer_c.step()

        loss_actor = -self.critic(obs, self.actor(obs)).mean()
        self.optimizer_a.zero_grad()
        loss_actor.backward()
        self.optimizer_a.step()

        self.soft_update()

        return loss_critic.item(), loss_actor.item()

    def save_model(self, remark):
        """Save the actor state dict to ``pretrained_model/<remark>.pt``."""
        # exist_ok avoids the exists()/mkdir() check-then-create race.
        os.makedirs('pretrained_model/', exist_ok=True)
        path = 'pretrained_model/{}.pt'.format(remark)
        print('Saving model to {}'.format(path))
        torch.save(self.actor.state_dict(), path)

    def load_model(self, remark):
        """Load the actor state dict from ``pretrained_model/<remark>.pt`` and re-sync the target actor."""
        path = 'pretrained_model/{}.pt'.format(remark)
        print('Loading model from {}'.format(path))
        # map_location lets a checkpoint saved on another device load here.
        model = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(model)
        # The target actor is not checkpointed; keep it consistent with the
        # weights that were just restored.
        self.actor_tar.load_state_dict(self.actor.state_dict())
Exemple #7
0
class Agent(object):
    """Actor-critic agent with a two-output critic, a target critic and
    autoencoder-based reward shaping.

    The critic is called as ``critic(state, action)`` and returns two
    Q-estimates. ``ll``, ``soft_update`` and ``hard_update`` are helpers
    resolved at module level.
    """

    def __init__(self, num_inputs, action_space, args):
        '''
        Builds the critic, its target copy, the policy and their optimizers.

        Parameters
        ----------
        num_inputs : int
            State dimension passed to the networks.
        action_space :
            Action space; ``action_space.shape[0]`` is used as the action
            dimension and the object itself is forwarded to the policy.
        args :
            Hyperparameters; reads ``gamma``, ``tau``, ``alpha``, ``alpha1``,
            ``policy``, ``target_update_interval``,
            ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and ``lr``.
        '''
        self.args = args
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.alpha1 = args.alpha1

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.critic = QNetwork(num_inputs, action_space.shape[0],
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)
        self.l = []
        if self.policy_type == "Gaussian":
            # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
            if self.automatic_entropy_tuning is True:
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_space.shape[0],
                                         args.hidden_size,
                                         action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:

            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs,
                                              action_space.shape[0],
                                              args.hidden_size,
                                              action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, evaluate=False):
        '''
        Returns an action for a single state as a numpy array.

        ``policy.sample`` returns a 3-tuple; the first element is used when
        ``evaluate`` is False and the third when it is True.
        '''
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_model1(self, model, new_params):
        '''
        Adds a flat vector of offsets onto the parameters of ``model`` in place.

        Parameters
        ----------
        model :
            Module whose parameters are modified.
        new_params : torch.Tensor
            1-D tensor; consecutive slices are reshaped to each parameter's
            shape, in ``model.parameters()`` order, and added to it.
        '''
        index = 0
        for params in model.parameters():
            params_length = params.numel()
            new_param = new_params[index:index + params_length]
            new_param = new_param.view(params.size())
            # Work on the parameter's own device rather than a hard-coded
            # "cuda:0", which fails on CPU-only hosts and on other GPUs.
            params.data.copy_(new_param.to(params.device) + params.data)
            index += params_length

    def _sample_batch(self, memory, batch_size):
        '''Samples a batch and returns float tensors on ``self.device``;
        reward and mask get a trailing unit dimension.'''
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        return state_batch, action_batch, reward_batch, next_state_batch, mask_batch

    @staticmethod
    def _step_critic(optimizer, qf1_loss, qf2_loss):
        '''Takes one optimizer step on the sum of both Q losses.

        Both losses come from a single forward pass; stepping once avoids
        back-propagating the second loss through weights that the first step
        already changed in place.'''
        optimizer.zero_grad()
        (qf1_loss + qf2_loss).backward()
        optimizer.step()

    def _step_policy(self, policy_loss):
        '''Takes one policy optimizer step with the gradient norm clipped to 1.'''
        self.policy_optim.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1)
        self.policy_optim.step()

    def update_parametersafter(self, memory, batch_size, updates, env, enco):
        '''
        Temporarily updates the parameters of the first agent.

        Parameters
        ----------
        memory : class 'replay_memory.ReplayMemory'

        batch_size : int

        updates : int
            Update counter; the target critic is soft-updated when it is a
            multiple of ``target_update_interval``.
        env : 'gym.wrappers.time_limit.TimeLimit'
            The environment of interest (not used in this method).
        enco : class
            The corresponding autoencoder

        Returns
        -------
        tuple of float
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)``;
            ``alpha_loss`` is always 0 here.
        '''
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = self._sample_batch(
            memory, batch_size)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)

            cat = torch.cat((next_state_batch, next_state_action), dim=-1)
            s = enco(cat)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)

            # Reward plus alpha * ll(input, autoencoder output) as a bonus term.
            next_q_value = reward_batch + self.alpha * ll(
                cat, s) + mask_batch * self.gamma * (min_qf_next_target)

        qf1, qf2 = self.critic(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

        self._step_critic(self.critic_optim, qf1_loss, qf2_loss)

        # The policy's pass through the critic must come AFTER the critic
        # step: the optimizer edits the critic weights in place, which would
        # invalidate a graph recorded beforehand (RuntimeError on backward in
        # PyTorch >= 1.5).
        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = (-min_qf_pi).mean()  # Jπ = 𝔼st∼D,εt∼N[− Q(st,f(εt;st))]

        self._step_policy(policy_loss)

        alpha_loss = torch.tensor(0.).to(self.device)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    def update_parametersdeter(self, memory, batch_size, updates, env, enco):
        '''
        Updates the parameters of the second agent.

        Parameters
        ----------
        memory : class 'replay_memory.ReplayMemory'

        batch_size : int

        updates : int
            Update counter; the target critic is soft-updated when it is a
            multiple of ``target_update_interval``.
        env : 'gym.wrappers.time_limit.TimeLimit'
            The environment of interest (not used in this method).
        enco : class
            The corresponding autoencoder

        Returns
        -------
        tuple of float
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)``;
            ``alpha_loss`` is always 0 here.
        '''
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = self._sample_batch(
            memory, batch_size)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)

            cat = torch.cat((state_batch, action_batch), dim=-1)
            s = enco(cat)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
            act, _, _ = self.policy.sample(state_batch)

            # Penalty: min-max normalised ll(cat, s) over the batch, times
            # ll(policy action, replayed action), scaled by alpha1.
            # NOTE(review): the normaliser is 0 when every ll(cat, s) value in
            # the batch is identical — confirm this cannot happen in practice.
            next_q_value = reward_batch - self.alpha1 * (
                (ll(cat, s) - torch.min(ll(cat, s))) /
                (torch.max(ll(cat, s)) - torch.min(ll(cat, s)))) * ll(
                    act, action_batch) + mask_batch * self.gamma * (
                        min_qf_next_target)  #refer to the paper
        qf1, qf2 = self.critic(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

        self._step_critic(self.critic_optim, qf1_loss, qf2_loss)

        # Policy forward pass after the critic step; see the in-place weight
        # update hazard described in _step_critic.
        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = (
            -min_qf_pi).mean()  # J_{π_2} = 𝔼st∼D,εt∼N[− Q(st,f(εt;st))]

        self._step_policy(policy_loss)

        alpha_loss = torch.tensor(0.).to(self.device)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    def X(self,
          ss,
          a,
          memory,
          batch_size,
          updates,
          env,
          enco,
          Qdac,
          pidac,
          QTdac,
          args,
          tenco,
          normalization=False):
        '''
        Updates the parameters of the first agent (with the influence function and intrinsic rewards).

        Parameters
        ----------
        ss : numpy array
           Current state (not used in this method).

        a : numpy array
           Action taken in ss (not used in this method).

        memory : class 'replay_memory.ReplayMemory'

        batch_size : int

        updates : int

        env : 'gym.wrappers.time_limit.TimeLimit'
            The environment of interest (not used in this method).

        enco :
            The corresponding autoencoder

        Qdac :
            The critic network of the second agent.

        pidac :
           The policy network of the second agent.

        QTdac :
           The target critic network of the second agent.

        args :
           Hyperparameters determined by the user; only ``args.lr`` is read.

        tenco :
           A virtual/proxy autoencoder used to calculate the frequency of (ss,a) w.r.t first agent's policy

        normalization : bool
           When True the ll(cat, s) bonus is divided by its batch maximum.

        Returns
        -------
        tuple of float
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)``. The two
            Q losses reported are those of the influence term computed last,
            not of this agent's critic update.
        '''

        # A fresh Adam optimizer (and thus fresh moment estimates) is created
        # for the proxy critic on every call.
        Qdac_optim = Adam(Qdac.parameters(), lr=args.lr)
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = self._sample_batch(
            memory, batch_size)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)

            cat = torch.cat((next_state_batch, next_state_action), dim=-1)
            s = enco(cat)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)

            if normalization:
                next_q_value = reward_batch + self.alpha * ll(
                    cat, s) / torch.max(
                        ll(cat,
                           s)) + mask_batch * self.gamma * (min_qf_next_target)
            else:
                next_q_value = reward_batch + self.alpha * ll(
                    cat, s) + mask_batch * self.gamma * (min_qf_next_target)

        qf1, qf2 = self.critic(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(
            qf2, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q2(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

        self._step_critic(self.critic_optim, qf1_loss, qf2_loss)

        # Update the proxy Q-functions of the second agent (Qdac, pidac, QTdac)
        with torch.no_grad():
            next_state_action, _, _ = pidac.sample(next_state_batch)

            qf1_next_target, qf2_next_target = QTdac(next_state_batch,
                                                     next_state_action)

            cat = torch.cat((state_batch, action_batch), dim=-1)
            s = tenco(cat)

            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
            act, _, _ = pidac.sample(state_batch)

            next_q_value = reward_batch - self.alpha1 * (
                (ll(cat, s) - torch.min(ll(cat, s))) /
                (torch.max(ll(cat, s)) - torch.min(ll(cat, s)))) * ll(
                    act, action_batch) + mask_batch * self.gamma * (
                        min_qf_next_target)
        qf1, qf2 = Qdac(
            state_batch, action_batch
        )  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(
            qf1, next_q_value
        )  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value)

        # Update 2nd Agent's proxy networks
        self._step_critic(Qdac_optim, qf1_loss, qf2_loss)

        # Find the value of F -- the influence function
        pi_BAC, _, _ = self.policy.sample(state_batch)

        with torch.no_grad():
            next_state_action, _, _ = pidac.sample(next_state_batch)
            qf1_next_target, qf2_next_target = QTdac(next_state_batch,
                                                     next_state_action)

            cat = torch.cat((state_batch, pi_BAC), dim=-1)
            s = enco(cat)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
            act, _, _ = pidac.sample(state_batch)

            next_q_value = reward_batch - self.alpha1 * (
                (ll(cat, s) - torch.min(ll(cat, s))) /
                (torch.max(ll(cat, s)) - torch.min(ll(cat, s)))) * ll(
                    act,
                    pi_BAC) + mask_batch * self.gamma * (min_qf_next_target)

        # Evaluated with gradients enabled so the influence term reaches the
        # first agent's policy through pi_BAC.
        qf1, qf2 = Qdac(state_batch, pi_BAC)

        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        minlossinf = torch.min(qf1_loss, qf2_loss)
        qf1_pi, qf2_pi = self.critic(state_batch, pi_BAC)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = (-min_qf_pi).mean()

        # Regulate the objective function of the first agent by adding F
        policy_loss = policy_loss + 0.1 * minlossinf
        self._step_policy(policy_loss)

        alpha_loss = torch.tensor(0.).to(self.device)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   enco,
                   suffix="",
                   actor_path=None,
                   critic_path=None,
                   enco_path=None):
        '''
        Saves the policy, critic and autoencoder state dicts.

        Paths left as None default to
        ``models/{actor,critic,enco}/IRLIA_{actor,critic,enco}_<env_name>_<suffix>``.
        '''
        os.makedirs('models/', exist_ok=True)

        if actor_path is None:
            actor_path = "models/actor/IRLIA_actor_{}_{}".format(
                env_name, suffix)
        if critic_path is None:
            critic_path = "models/critic/IRLIA_critic_{}_{}".format(
                env_name, suffix)
        if enco_path is None:
            enco_path = "models/enco/IRLIA_enco_{}_{}".format(env_name, suffix)

        # The default targets live in sub-folders of models/ that nothing else
        # creates; make every parent directory before saving. Non-path targets
        # (e.g. file-like objects) are passed through untouched.
        for path in (actor_path, critic_path, enco_path):
            if isinstance(path, (str, os.PathLike)):
                parent = os.path.dirname(path)
                if parent:
                    os.makedirs(parent, exist_ok=True)

        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)
        torch.save(enco.state_dict(), enco_path)

    # Load model parameters
    def load_model(self, enco, actor_path, critic_path, enco_path):
        '''
        Loads policy, critic and autoencoder weights; any path given as None is skipped.
        '''
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        # map_location lets checkpoints saved on another device load here.
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=self.device))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=self.device))
            # The target critic is not checkpointed; re-sync it so bootstrap
            # targets match the restored critic.
            self.critic_target.load_state_dict(self.critic.state_dict())
        if enco_path is not None:
            enco.load_state_dict(
                torch.load(enco_path, map_location=self.device))
Exemple #8
0
class SAC_fourier(object):
    def __init__(self, num_inputs, action_space, args):
        """Build the critic, its target copy, the policy and their optimizers.

        Args:
            num_inputs: First constructor argument of every network
                (presumably the observation dimension -- TODO confirm).
            action_space: Object exposing ``shape``; ``shape[0]`` is used as
                the action dimension and the object itself is forwarded to
                the networks.
            args: Hyper-parameter namespace. Reads ``gamma``, ``tau``,
                ``alpha``, ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``Qapproximation``,
                ``hidden_size`` and ``lr``; ``filter`` and ``TDfilter`` are
                optional and default to ``'none'``.

        Raises:
            ValueError: If ``args.Qapproximation`` is neither ``'fourier'``
                nor ``'byactiondim'``.
        """
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.Qapproximation = args.Qapproximation
        # Optional settings: only a missing attribute selects the default.
        self.filter = getattr(args, 'filter', 'none')
        self.TDfilter = getattr(args, 'TDfilter', 'none')
        # NOTE(review): update_parameters reads self.cutoffX for the
        # 'rec_inside' / 'rec_outside' filters, but nothing assigns it here.

        action_dim = action_space.shape[0]
        if self.Qapproximation == 'fourier':
            self.critic = Qfourier(num_inputs,
                                   action_dim,
                                   256,
                                   action_space,
                                   gridsize=20).to(device=self.device)
            # target doesn't need to filter Q in high frequencies
            self.critic_target = Qfourier(num_inputs,
                                          action_dim,
                                          256,
                                          action_space,
                                          gridsize=20).to(device=self.device)
        elif self.Qapproximation == 'byactiondim':
            self.critic = Qbyactiondim(num_inputs, action_dim, 256, 8, 5,
                                       action_space).to(device=self.device)
            self.critic_target = Qbyactiondim(
                num_inputs, action_dim, 256, 8, 5,
                action_space).to(device=self.device)
        else:
            # Fail here with a clear message instead of a later
            # AttributeError on the missing self.critic.
            raise ValueError(
                "Unsupported Qapproximation: {!r}".format(
                    self.Qapproximation))

        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
        # Start the target critic from the same weights as the critic.
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target Entropy = -dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
            # Truthiness check, matching the test used in update_parameters.
            if self.automatic_entropy_tuning:
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_dim,
                                         args.hidden_size,
                                         action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:
            # The non-Gaussian policy runs without an entropy term.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs,
                                              action_dim,
                                              args.hidden_size,
                                              action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, eval=False):
        """Return the policy's action for one unbatched state as a NumPy array.

        The state gets a leading batch axis before ``policy.sample`` is called.
        When ``eval`` equals ``False`` the first element of the sampled
        4-tuple is used, otherwise the third one.
        """
        batched_state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        first_output, _, third_output, _ = self.policy.sample(batched_state)
        chosen = first_output if eval == False else third_output
        as_numpy = chosen.detach().cpu().numpy()
        # Drop the batch axis that was added above.
        return as_numpy[0]

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic, policy and (optionally) temperature update.

        Args:
            memory: Replay buffer whose ``sample(batch_size=...)`` returns
                ``(state, action, reward, next_state, mask)``.
            batch_size: Number of transitions to draw.
            updates: Update counter; the target critic is soft-updated when
                it is a multiple of ``target_update_interval``.

        Returns:
            An 11-tuple of numbers: critic loss, the constant ``0``, policy
            loss, alpha loss, alpha, mean of ``std``, the batch means of the
            first three ``filterbw`` columns, mean filter power (``0.0`` when
            the target critic produced none) and the mean critic value of the
            policy's actions.

        Raises:
            ValueError: If the ``Qapproximation`` / ``TDfilter`` / ``filter``
                combination is not handled by this method.
        """

        def first_output(result):
            # Some call sites unpack the critic output as a pair while others
            # use it as a bare tensor; accept both forms.
            return result[0] if isinstance(result, (tuple, list)) else result

        is_fourier = self.Qapproximation == 'fourier'
        by_dim = self.Qapproximation == 'byactiondim'
        filtered_target = is_fourier and self.TDfilter == 'rec_inside'
        plain_target = by_dim or (is_fourier and self.TDfilter == 'none')
        # Reject unsupported settings before any optimizer step is taken.
        if not (filtered_target or plain_target):
            raise ValueError(
                "Unsupported Qapproximation/TDfilter combination: "
                "{!r}/{!r}".format(self.Qapproximation, self.TDfilter))
        if is_fourier and self.filter not in ('none', 'rec_inside',
                                              'rec_outside'):
            raise ValueError("Unsupported filter: {!r}".format(self.filter))

        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # After automatic tuning self.alpha is a tensor attached to
        # log_alpha's graph; use a detached value so that graph is not walked
        # by both the critic and the policy backward passes.
        alpha = self.alpha.detach() if torch.is_tensor(
            self.alpha) else self.alpha

        qf1, filterbw = self.critic(state_batch, action_batch)

        filterpower = None
        with torch.no_grad():
            next_state_action, next_state_log_pi, _, std = self.policy.sample(
                next_state_batch)
            if plain_target:
                qf1_next_target = first_output(
                    self.critic_target(next_state_batch, next_state_action))

        if filtered_target:
            # Evaluated with gradients enabled, as in the original code, so
            # the critic's filterbw stays part of the graph.
            qf1_next_target, filterpower = self.critic_target(
                next_state_batch,
                next_state_action,
                std=std,  # detached
                logprob=None,
                mu=None,
                filter=self.TDfilter,
                filterbw=filterbw)

        min_qf_next_target = qf1_next_target - alpha * next_state_log_pi
        next_q_value = reward_batch + mask_batch * self.gamma * (
            min_qf_next_target)

        # Filter regularization: squared logit of filterbw.
        bandwidth_logit = torch.log(filterbw / (1 - filterbw))  # (-inf,inf)
        qf1_loss = F.mse_loss(qf1, next_q_value) + torch.mean(
            torch.pow(bandwidth_logit, 2))

        # Update the critic BEFORE building the policy graph: stepping the
        # optimizer modifies the critic weights in place, which would
        # invalidate a policy graph that was recorded earlier.
        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        if by_dim:
            pi, log_pi, _, _ = self.policy.sample(state_batch)
            qf1_pi = first_output(self.critic(state_batch, pi))
        elif self.filter == 'none':
            pi, log_pi, _, std = self.policy.sample(state_batch)
            qf1_pi, _ = self.critic(state_batch, pi)
        elif self.filter == 'rec_inside':
            with torch.no_grad():
                _, _, _, std = self.policy.sample(state_batch)
            pi, log_pi, _, _ = self.policy.sample(state_batch)
            qf1_pi, _ = self.critic(
                state_batch,
                pi,
                std=std,  # detached
                logprob=None,
                mu=None,
                filter=self.filter,
                cutoffXX=self.cutoffX)
        else:  # 'rec_outside'
            with torch.no_grad():
                _, _, mu, std, log_pi_ = self.policy.sample_for_spectrogram(
                    state_batch)
            pi, log_pi, _, _, _ = self.policy.sample_for_spectrogram(
                state_batch)
            qf1_pi, _ = self.critic(
                state_batch,
                pi,
                std=std,  # detached
                logprob=log_pi_,  # sum of logprobs w/o tanh correction
                mu=mu,
                filter=self.filter,
                cutoffXX=self.cutoffX)

        min_qf_pi = qf1_pi

        policy_loss = ((alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        # Logging values are needed by the return statement in BOTH branches
        # above, so they are computed unconditionally.
        bandwidth_mean = filterbw.detach().mean(0)
        filterbw0_tlogs = bandwidth_mean[0]
        filterbw1_tlogs = bandwidth_mean[1]
        filterbw2_tlogs = bandwidth_mean[2]
        if filterpower is None:
            # Only the 'rec_inside' target produces a filter power.
            filterpower_tlogs = torch.tensor(0.)
        else:
            filterpower_tlogs = torch.mean(filterpower)
        min_qf_pi_tlogs = torch.mean(min_qf_pi)

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), 0, policy_loss.item(), alpha_loss.item(), alpha_tlogs.item(), \
            std.mean().item(), \
            filterbw0_tlogs.item(), filterbw1_tlogs.item(), filterbw2_tlogs.item(), \
            filterpower_tlogs.item(), \
            min_qf_pi_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None):
        """Store the policy and critic state dicts on disk.

        The ``models/`` folder is created when it does not exist yet. Paths
        that are not supplied default to ``./models/sac_actor_<env>_<suffix>``
        and ``./models/sac_critic_<env>_<suffix>``.
        """
        models_dir = 'models/'
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        actor_path = ("./models/sac_actor_{}_{}".format(env_name, suffix)
                      if actor_path is None else actor_path)
        critic_path = ("./models/sac_critic_{}_{}".format(env_name, suffix)
                       if critic_path is None else critic_path)

        print('Saving models to {} and {}'.format(actor_path, critic_path))
        actor_state = self.policy.state_dict()
        torch.save(actor_state, actor_path)
        critic_state = self.critic.state_dict()
        torch.save(critic_state, critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        """Restore policy and critic weights from checkpoint files.

        Args:
            actor_path: Policy checkpoint, or ``None`` to leave the policy
                untouched.
            critic_path: Critic checkpoint, or ``None`` to leave the critic
                untouched.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        # Remap stored tensors onto the agent's device so a checkpoint written
        # on a GPU can still be restored on a CPU-only host.
        target_device = getattr(self, "device", None)
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=target_device))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=target_device))

    def spectrum(self, memory, batch_size, action_space, To=2, modes=10):
        """Evaluate ``critic.spectrum`` on a batch drawn from ``memory``.

        Args:
            memory: Replay buffer whose ``sample(batch_size=...)`` returns a
                7-tuple ``(state, action, reward, next_state, mask, log_prob,
                std)`` -- two more fields than the buffer used by
                ``update_parameters``.
            batch_size: Number of transitions to draw.
            action_space: Forwarded unchanged to ``critic.spectrum``.
            To: Forwarded unchanged to ``critic.spectrum`` (meaning not
                visible here -- TODO confirm).
            modes: Forwarded unchanged to ``critic.spectrum`` (meaning not
                visible here -- TODO confirm).

        NOTE(review): the visible body ends after ``qf1`` is computed and
        returns nothing; the listing looks truncated, so the intended return
        value should be confirmed against the full source.
        """

        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch, \
            log_prob_batch, std_batch = memory.sample(batch_size=batch_size)

        # Convert every field to a float tensor on the agent's device.
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        # std and log-prob have axis 1 squeezed away (only if it has size 1).
        std_batch = torch.FloatTensor(std_batch).to(self.device).squeeze(1)
        log_prob_batch = torch.FloatTensor(log_prob_batch).to(
            self.device).squeeze(1)
        # Turn the stored log-probabilities back into probabilities.
        prob_batch = torch.exp(log_prob_batch)

        # Evaluation only: no gradients are recorded for the critic call.
        with torch.no_grad():

            qf1 = self.critic.spectrum(state_batch, action_batch, std_batch,
                                       prob_batch, action_space, To, modes)
Exemple #9
0
class SAC(object):
    def __init__(self, num_inputs, action_space, args):
        """Assemble the two-output critic, its target copy and the policy.

        Hyper-parameters are read from ``args`` (``gamma``, ``tau``,
        ``alpha``, ``policy``, ``target_update_interval``,
        ``automatic_entropy_tuning``, ``cuda``, ``hidden_size``, ``lr``).
        ``action_space.shape[0]`` is used as the action dimension.
        """
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        action_dim = action_space.shape[0]

        # Critic: scores a (state, action) pair.
        self.critic = QNetwork(num_inputs, action_dim,
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        # Target critic: a second network that starts from identical weights
        # and is only changed through soft updates.
        self.critic_target = QNetwork(num_inputs, action_dim,
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type != "Gaussian":
            # Non-Gaussian policy: no entropy term and no temperature tuning.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_dim,
                                              args.hidden_size).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
            return

        # Target Entropy = -dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
        if self.automatic_entropy_tuning == True:
            shape_tensor = torch.Tensor(action_space.shape).to(self.device)
            self.target_entropy = -torch.prod(shape_tensor).item()
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        self.policy = GaussianPolicy(num_inputs, action_dim,
                                     args.hidden_size).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, eval=False):
        """Pick an action for one unbatched state and return it as NumPy.

        ``policy.sample`` yields a 3-tuple; its first element is used when
        ``eval`` equals ``False`` and its last element otherwise.
        """
        batched_state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        sampled, _, alternative = self.policy.sample(batched_state)
        if eval == False:
            picked = sampled
        else:
            picked = alternative
        # Strip the batch axis that was added above.
        return picked.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic, policy and (optionally) temperature update.

        Args:
            memory: Replay buffer whose ``sample(batch_size=...)`` returns
                ``(state, action, reward, next_state, mask)``.
            batch_size: Number of transitions to draw.
            updates: Update counter; the target critic is soft-updated when
                it is a multiple of ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as
            Python numbers.
        """
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # The bootstrap target is treated as a constant: nothing is
        # back-propagated through the policy or the target critic here.
        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        # One optimizer step on the summed loss. Stepping twice (once per
        # head) mutates weights in place between the two backward passes.
        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()
        self.critic_optim.step()

        # Build the policy graph only AFTER the critic step: the optimizer
        # updates the critic weights in place, which would invalidate a graph
        # recorded earlier and make policy_loss.backward() fail.
        actBatch, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, actBatch)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # J_pi = E[alpha * log pi(a|s) - Q(s, a)] with a sampled from the policy
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        # once in a few cycles of learning we update the target net
        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None):
        """Persist the policy and critic state dicts.

        Creates ``models/`` when it is missing. Missing paths default to
        ``models/sac_actor_<env>_<suffix>`` and
        ``models/sac_critic_<env>_<suffix>``.
        """
        folder_missing = not os.path.exists('models/')
        if folder_missing:
            os.makedirs('models/')

        if actor_path is None:
            actor_template = "models/sac_actor_{}_{}"
            actor_path = actor_template.format(env_name, suffix)
        if critic_path is None:
            critic_template = "models/sac_critic_{}_{}"
            critic_path = critic_template.format(env_name, suffix)

        announcement = 'Saving models to {} and {}'.format(
            actor_path, critic_path)
        print(announcement)
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        """Restore policy and critic weights from checkpoint files.

        Args:
            actor_path: Policy checkpoint, or ``None`` to skip the policy.
            critic_path: Critic checkpoint, or ``None`` to skip the critic.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        # Remap stored tensors onto the agent's device so a checkpoint written
        # on a GPU can still be restored on a CPU-only host.
        target_device = getattr(self, "device", None)
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=target_device))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=target_device))
Exemple #10
0
    shuffle=True,  # 要不要打乱数据 (打乱比较好)
)

# Supervised pre-training script: fits a DeterministicPolicy to the
# (batch_x, batch_u) pairs yielded by `loader` with SGD and an MSE loss.
# `loader`, `x_eval` and `u_eval` are defined outside this excerpt.
action_space = spaces.Box(low=-np.array([0.5, 2.0]),
                          high=+np.array([0.5, 2.0]),
                          shape=(2, ),
                          dtype=np.float32)
# The network is placed on the GPU unconditionally.
net = DeterministicPolicy(6, 2, 256, action_space).to('cuda')

optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss

for epoch in range(1000):  # one pass over the whole dataset per epoch (the original note said 3 passes; the loop runs 1000)
    for step, (batch_x,
               batch_u) in enumerate(loader):  # each step the loader yields one mini-batch to learn from
        # print(batch_x.shape)
        prediction = net(batch_x.cuda())  # input x and predict based on x

        loss = loss_func(prediction,
                         batch_u.cuda())  # must be (1. nn output, 2. target)

        optimizer.zero_grad()  # clear gradients for next train
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients
        # Report the held-out loss once per epoch, after the first mini-batch.
        if (step == 0):
            prediction_eval = net(x_eval.cuda())
            loss_eval = loss_func(prediction_eval, u_eval.cuda())
            print('Epoch: ', epoch, '| Step: ', step, '| loss : ', loss_eval)
    # Overwrite the checkpoint every 100 epochs (epochs 0, 100, ..., 900).
    if (epoch % 100 == 0):
        torch.save(net.state_dict(), "models/pretrain")
Exemple #11
0
class BAC(object):
    def __init__(self, num_inputs, action_space, args):
        """Create the critic, target critic, policy and their optimizers.

        ``args`` is stored on the instance and supplies ``gamma``, ``tau``,
        ``alpha``, ``policy``, ``target_update_interval``,
        ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and ``lr``.
        ``action_space.shape[0]`` is used as the action dimension.
        """
        self.args = args
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        action_dim = action_space.shape[0]

        self.critic = QNetwork(num_inputs, action_dim,
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        # The target critic starts as an exact copy of the critic.
        self.critic_target = QNetwork(num_inputs, action_dim,
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)
        self.l = []

        if self.policy_type != "Gaussian":
            # Non-Gaussian policy: temperature tuning is switched off
            # (self.alpha keeps the value taken from args).
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_dim,
                                              args.hidden_size,
                                              action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
            return

        # Target Entropy = -dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
        if self.automatic_entropy_tuning is True:
            shape_tensor = torch.Tensor(action_space.shape).to(self.device)
            self.target_entropy = -torch.prod(shape_tensor).item()
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

        self.policy = GaussianPolicy(num_inputs, action_dim,
                                     args.hidden_size,
                                     action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, evaluate=False):
        """Return the policy's action for one unbatched state as NumPy.

        The first element of ``policy.sample``'s 3-tuple is used when
        ``evaluate`` is exactly ``False``; otherwise the third element is.
        """
        batched_state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        outputs = self.policy.sample(batched_state)
        exploratory, _, evaluation = outputs
        picked = exploratory if evaluate is False else evaluation
        # Index 0 removes the batch axis added above.
        return picked.detach().cpu().numpy()[0]

    def update_model1(self, model, new_params):
        """Add a flat vector of increments to ``model``'s parameters in place.

        Args:
            model: Module whose parameters are updated.
            new_params: 1-D tensor holding one increment per parameter
                element, laid out in ``model.parameters()`` order.
        """
        offset = 0
        for params in model.parameters():
            count = params.numel()
            increment = new_params[offset:offset + count].view(params.size())
            # Work on the parameter's own device instead of a hard-coded
            # "cuda:0", so the update also runs on CPU-only hosts.
            params.data.copy_(params.detach() + increment.to(params.device))
            offset += count

    def update_parametersbefore(self, memory, batch_size, updates):
        '''
        Temporarily updates the parameters of the agent w.r.t. external rewards only.

        Parameters
        ----------
        memory : class 'replay_memory.ReplayMemory'
            Buffer whose ``sample(batch_size=...)`` returns
            ``(state, action, reward, next_state, mask)``.
        batch_size : int
            Number of transitions to draw.
        updates : int
            Update counter; the target critic is soft-updated when it is a
            multiple of ``target_update_interval``.

        Returns
        -------
        tuple
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as
            Python numbers.
        '''
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # Bootstrap target without an entropy term; treated as a constant.
        with torch.no_grad():
            next_state_action, _, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        # One optimizer step on the summed loss instead of two steps that
        # mutate the weights in place between backward passes.
        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()
        self.critic_optim.step()

        # Build the policy graph only AFTER the critic step: the optimizer
        # updates the critic weights in place, which would invalidate a graph
        # recorded earlier and make policy_loss.backward() fail.
        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # J_pi = E[-Q(s, a)] with a sampled from the policy (no entropy term)
        policy_loss = (-min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    def update_parametersafter(self, memory, batch_size, updates, env, enco):
        '''
        Updates the parameters of the agent w.r.t. external and intrinsic rewards.

        Parameters
        ----------
        memory : class 'replay_memory.ReplayMemory'
            Buffer whose ``sample(batch_size=...)`` returns
            ``(state, action, reward, next_state, mask)``.
        batch_size : int
            Number of transitions to draw.
        updates : int
            Update counter. The policy is stepped when it is a multiple of 3
            and the target critic is soft-updated when it is a multiple of
            ``target_update_interval``.
        env : 'gym.wrappers.time_limit.TimeLimit'
            The environment of interest (not used by this method).
        enco : class
            The corresponding autoencoder; called on the concatenated
            next-state / next-action batch.

        Returns
        -------
        tuple
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as
            Python numbers; ``alpha_loss`` is always ``0``.
        '''
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, _, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)

            encoder_input = torch.cat((next_state_batch, next_state_action),
                                      dim=-1)
            encoder_output = enco(encoder_input)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)

            # `ll` is defined elsewhere in the file (presumably a
            # reconstruction loss -- TODO confirm). It is evaluated once and
            # normalised by its own maximum to form the intrinsic bonus.
            intrinsic = ll(encoder_input, encoder_output)
            next_q_value = (reward_batch +
                            self.alpha * intrinsic / torch.max(intrinsic) +
                            mask_batch * self.gamma * min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)

        # One optimizer step on the summed loss instead of two steps that
        # mutate the weights in place between backward passes.
        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()
        self.critic_optim.step()

        # The policy loss is evaluated AFTER the critic step (stepping the
        # optimizer changes the critic weights in place, which would
        # invalidate an earlier graph). A graph is only recorded when the
        # policy is actually stepped this round.
        policy_due = updates % 3 == 0
        with torch.set_grad_enabled(policy_due):
            pi, _, _ = self.policy.sample(state_batch)
            qf1_pi, qf2_pi = self.critic(state_batch, pi)
            min_qf_pi = torch.min(qf1_pi, qf2_pi)
            policy_loss = (-min_qf_pi).mean()

        if policy_due:
            self.policy_optim.zero_grad()
            policy_loss.backward()
            self.policy_optim.step()

        # The temperature is not tuned in this update.
        alpha_loss = torch.tensor(0.).to(self.device)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   enco,
                   suffix="",
                   actor_path=None,
                   critic_path=None,
                   enco_path=None):
        """Write the policy, critic and autoencoder state dicts to disk.

        Args:
            env_name: Environment name embedded in the default file names.
            enco: Object whose ``state_dict()`` is saved next to the agent.
            suffix: Extra tag appended to the default file names.
            actor_path: Destination of the policy weights. Defaults to
                ``models/actor/sac_actor_<env_name>_<suffix>``.
            critic_path: Destination of the critic weights. Defaults to
                ``models/critic/sac_critic_<env_name>_<suffix>``.
            enco_path: Destination of the autoencoder weights. Defaults to
                ``models/enco/sac_enco_<env_name>_<suffix>``.
        """
        if actor_path is None:
            actor_path = "models/actor/sac_actor_{}_{}".format(
                env_name, suffix)
        if critic_path is None:
            critic_path = "models/critic/sac_critic_{}_{}".format(
                env_name, suffix)
        if enco_path is None:
            enco_path = "models/enco/sac_enco_{}_{}".format(env_name, suffix)

        # torch.save does not create missing folders and the default targets
        # live in sub-folders of models/, so create every parent directory.
        os.makedirs('models/', exist_ok=True)
        for target in (actor_path, critic_path, enco_path):
            # File-like targets (also accepted by torch.save) have no parent.
            if isinstance(target, (str, os.PathLike)):
                parent = os.path.dirname(target)
                if parent:
                    os.makedirs(parent, exist_ok=True)

        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)
        torch.save(enco.state_dict(), enco_path)

    # Load model parameters
    def load_model(self, enco, actor_path, critic_path, enco_path):
        """Restore policy, critic and autoencoder weights from checkpoints.

        Args:
            enco: Object whose weights are restored from ``enco_path``.
            actor_path: Policy checkpoint, or ``None`` to skip the policy.
            critic_path: Critic checkpoint, or ``None`` to skip the critic.
            enco_path: Autoencoder checkpoint, or ``None`` to skip it.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        # Remap stored tensors onto the agent's device so a checkpoint written
        # on a GPU can still be restored on a CPU-only host.
        target_device = getattr(self, "device", None)
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=target_device))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=target_device))
        if enco_path is not None:
            enco.load_state_dict(
                torch.load(enco_path, map_location=target_device))
Exemple #12
0
class SAC_FORK(object):
    """Soft Actor-Critic agent augmented with learned dynamics and reward models.

    On top of the twin-Q critic and the policy, the agent fits ``sysmodel``
    (predicts the next state from state and action) and ``sysr`` (predicts the
    reward from state, next state and action).  Whenever the dynamics loss of
    the current batch is below ``sys_threshold`` and ``sys_weight`` is
    non-zero, a two-step model rollout term is added to the policy loss.
    """

    def __init__(self, num_inputs, action_space, args):
        """Build all networks and optimizers.

        Args:
            num_inputs: Observation size passed to every network constructor.
            action_space: Action space; ``action_space.shape[0]`` is used as
                the action dimension and the object is handed to the policy.
            args: Hyper-parameter namespace.  Reads ``gamma``, ``tau``,
                ``alpha``, ``policy_type``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size``,
                ``lr``, ``sys_hidden_size``, ``sysr_hidden_size``,
                ``sys_threshold`` and ``sys_weight``.
        """
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha

        self.policy_type = args.policy_type
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if args.cuda else "cpu")

        self.critic = QNetwork(num_inputs, action_space.shape[0],
                               args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        # Learned dynamics model: (state, action) -> next state.
        self.sysmodel = SysModel(num_inputs, action_space.shape[0],
                                 args.sys_hidden_size,
                                 args.sys_hidden_size).to(self.device)
        self.sysmodel_optimizer = Adam(self.sysmodel.parameters(), lr=args.lr)

        # Clamp range applied to predicted states.  Both start at 0, so
        # predictions are clamped to 0 until the caller assigns real bounds.
        self.obs_upper_bound = 0  # state space upper bound
        self.obs_lower_bound = 0  # state space lower bound

        # Learned reward model: (state, next state, action) -> reward.
        self.sysr = Sys_R(num_inputs, action_space.shape[0],
                          args.sysr_hidden_size,
                          args.sysr_hidden_size).to(self.device)
        self.sysr_optimizer = torch.optim.Adam(self.sysr.parameters(),
                                               lr=args.lr)

        self.sys_threshold = args.sys_threshold
        self.sys_weight = args.sys_weight
        self.sysmodel_loss = 0
        self.sysr_loss = 0
        # Number of updates that included the model-based policy term.
        # update_parameters() increments it, so it must exist up front.
        self.update_sys = 0

        if self.policy_type == "Gaussian":
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2).
            # Truthiness check, matching the one used in update_parameters().
            if self.automatic_entropy_tuning:
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_space.shape[0],
                                         args.hidden_size,
                                         action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:
            # Deterministic policy: no entropy term at all.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs,
                                              action_space.shape[0],
                                              args.hidden_size,
                                              action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, evaluate=False):
        """Return an action for a single state as a NumPy array.

        With ``evaluate`` left at ``False`` the first value returned by
        ``policy.sample`` is used, otherwise the third one.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        if evaluate is False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Run one update of critic, dynamics model, reward model and policy.

        Args:
            memory: Replay buffer exposing ``sample(batch_size=...)`` that
                returns ``(state, action, reward, next_state, mask)``.
            batch_size: Number of transitions to sample.
            updates: Update counter; the target critic is soft-updated when
                it is a multiple of ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as floats.
        """
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # Bootstrapped soft target; no gradient flows through it.
        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in policy improvement.
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # Fit the dynamics model on the sampled transitions.
        predict_next_state = self.sysmodel(state_batch, action_batch)
        predict_next_state = predict_next_state.clamp(self.obs_lower_bound,
                                                      self.obs_upper_bound)
        sysmodel_loss = F.smooth_l1_loss(predict_next_state,
                                         next_state_batch.detach())
        self.sysmodel_optimizer.zero_grad()
        sysmodel_loss.backward()
        self.sysmodel_optimizer.step()
        self.sysmodel_loss = sysmodel_loss.item()

        # Fit the reward model on the sampled transitions.
        predict_reward = self.sysr(state_batch, next_state_batch, action_batch)
        sysr_loss = F.mse_loss(predict_reward, reward_batch.detach())
        self.sysr_optimizer.zero_grad()
        sysr_loss.backward()
        self.sysr_optimizer.step()
        self.sysr_loss = sysr_loss.item()

        # Only trust the model rollout when the dynamics loss is small enough.
        s_flag = 1 if sysmodel_loss.item() < self.sys_threshold else 0

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # Standard SAC policy objective: alpha * log pi - Q.
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        if s_flag == 1 and self.sys_weight != 0:
            # Two-step rollout through the learned models starting from pi.
            p_next_state = self.sysmodel(state_batch, pi)
            p_next_state = p_next_state.clamp(self.obs_lower_bound,
                                              self.obs_upper_bound)
            p_next_r = self.sysr(state_batch, p_next_state.detach(), pi)

            pi2, log_pi2, _ = self.policy.sample(p_next_state.detach())
            p_next_state2 = self.sysmodel(p_next_state, pi2)
            p_next_state2 = p_next_state2.clamp(self.obs_lower_bound,
                                                self.obs_upper_bound)
            p_next_r2 = self.sysr(p_next_state.detach(),
                                  p_next_state2.detach(), pi2)

            pi3, log_pi3, _ = self.policy.sample(p_next_state2.detach())
            qf3_pi, qf4_pi = self.critic(p_next_state2.detach(), pi3)
            min_qf_pi2 = torch.min(qf3_pi, qf4_pi)

            # Negative entropy-regularised return of the rollout:
            # two predicted rewards plus the discounted soft value at the end.
            sys_loss = (-p_next_r + self.alpha * log_pi - self.gamma *
                        (p_next_r2 - self.alpha * log_pi2) + self.gamma**2 *
                        ((self.alpha * log_pi3) - min_qf_pi2)).mean()
            policy_loss += self.sys_weight * sys_loss
            self.update_sys += 1

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save(self, filename):
        """Write network and optimizer state dicts to ``filename + <suffix>``.

        Suffixes: ``_critic``, ``_actor``, ``_sysmodel``, ``_reward_model``
        and the same with ``_optimizer`` appended.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optim.state_dict(),
                   filename + "_critic_optimizer")

        torch.save(self.policy.state_dict(), filename + "_actor")
        torch.save(self.policy_optim.state_dict(),
                   filename + "_actor_optimizer")

        torch.save(self.sysmodel.state_dict(), filename + "_sysmodel")
        torch.save(self.sysmodel_optimizer.state_dict(),
                   filename + "_sysmodel_optimizer")

        torch.save(self.sysr.state_dict(), filename + "_reward_model")
        torch.save(self.sysr_optimizer.state_dict(),
                   filename + "_reward_model_optimizer")

    def load(self, filename):
        """Restore everything written by :meth:`save` under the same prefix.

        The files are read with exactly the names ``save`` produces, and are
        mapped onto ``self.device`` so a checkpoint written on another device
        can still be loaded.  The target critic is rebuilt as a copy of the
        restored critic.
        """

        def read(suffix):
            return torch.load(filename + suffix, map_location=self.device)

        self.critic.load_state_dict(read("_critic"))
        self.critic_optim.load_state_dict(read("_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)

        self.policy.load_state_dict(read("_actor"))
        self.policy_optim.load_state_dict(read("_actor_optimizer"))

        self.sysmodel.load_state_dict(read("_sysmodel"))
        self.sysmodel_optimizer.load_state_dict(read("_sysmodel_optimizer"))

        self.sysr.load_state_dict(read("_reward_model"))
        self.sysr_optimizer.load_state_dict(read("_reward_model_optimizer"))
Exemple #13
0
class ADVSAC(object):
    """Soft Actor-Critic whose critic targets are penalised by an adversarial term.

    For each target Q-head the state is moved by ``adv_epsilon`` against the
    normalised gradient of that head with respect to the state.  The drop in
    the head's value (clamped to ``[0, 10000]``) is scaled by ``adv_lambda``
    and subtracted from the bootstrapped target of the matching critic head.
    """

    def __init__(self, num_inputs, action_space, args):
        """Build the networks and optimizers.

        Args:
            num_inputs: Observation size passed to the network constructors.
            action_space: Action space; ``action_space.shape[0]`` is used as
                the action dimension and the object is handed to the policy.
            args: Hyper-parameter namespace.  Reads ``gamma``, ``tau``,
                ``alpha``, ``adv_epsilon``, ``adv_lambda``, ``policy``,
                ``target_update_interval``, ``automatic_entropy_tuning``,
                ``hidden_size`` and ``lr``.
        """
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.adv_epsilon = args.adv_epsilon
        self.adv_lambda = args.adv_lambda

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        # The device is chosen from CUDA availability, not from ``args``.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2).
            # Truthiness check, matching the one used in update_parameters().
            if self.automatic_entropy_tuning:
                self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:
            # Deterministic policy: no entropy term at all.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, eval=False):
        """Return an action for a single state as a NumPy array.

        With ``eval`` equal to ``False`` the first value returned by
        ``policy.sample`` is used, otherwise the third one.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        # ``==`` is kept on purpose so existing callers see unchanged behavior.
        if eval == False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic / policy / temperature update.

        Args:
            memory: Replay buffer exposing ``sample(batch_size=...)`` that
                returns ``(state, action, reward, next_state, mask)``.
            batch_size: Number of transitions to sample.
            updates: Update counter; the target critic is soft-updated when
                it is a multiple of ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha,
            mean_adv_error_1, mean_adv_error_2)`` as floats.
        """
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        # Gradient of each target Q-head with respect to the state.
        # torch.autograd.grad returns each gradient separately, so the second
        # one is not contaminated by the first (``.grad`` would accumulate),
        # and no gradients are left behind on the target network parameters.
        adv_input = state_batch.clone().requires_grad_(True)
        qf1, qf2 = self.critic_target(adv_input, action_batch)
        state_grad_1 = torch.autograd.grad(qf1.mean(), adv_input, retain_graph=True)[0]
        state_grad_2 = torch.autograd.grad(qf2.mean(), adv_input)[0]
        adv_perturb_1 = F.normalize(state_grad_1)
        adv_perturb_2 = F.normalize(state_grad_2)

        # Adversarial Q1 / Q2 estimates; everything here is a fixed target.
        with torch.no_grad():
            adv_state_1 = state_batch - self.adv_epsilon * adv_perturb_1
            adv_state_2 = state_batch - self.adv_epsilon * adv_perturb_2
            adv_qf1, _ = self.critic_target(adv_state_1, action_batch)
            _, adv_qf2 = self.critic_target(adv_state_2, action_batch)
            adv_error_1 = torch.clamp(qf1 - adv_qf1, 0.0, 10000.0)
            # Each head is compared with its own clean estimate.
            adv_error_2 = torch.clamp(qf2 - adv_qf2, 0.0, 10000.0)
            next_q_value_1 = next_q_value - self.adv_lambda * adv_error_1
            next_q_value_2 = next_q_value - self.adv_lambda * adv_error_2

        qff1, qff2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qff1, next_q_value_1)
        qf2_loss = F.mse_loss(qff2, next_q_value_2)
        qf_loss = qf1_loss + qf2_loss

        # A single critic step on the summed loss.  Stepping between two
        # backward passes modifies weights that the second pass still needs.
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # The policy loss is built after the critic step so its graph refers
        # to the current critic weights.
        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # Standard SAC policy objective: alpha * log pi - Q.
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item(), adv_error_1.mean().item(), adv_error_2.mean().item()

    # Save model parameters
    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        """Save policy and critic weights, by default under ``./train/<env_name>/``.

        The ``./train/<env_name>/`` directory is created if missing.  Explicit
        ``actor_path`` / ``critic_path`` values override the default names.
        """
        os.makedirs('./train/{}/'.format(env_name), exist_ok=True)

        if actor_path is None:
            actor_path = "./train/{}/sac_actor_{}_{}.pth".format(env_name, env_name, suffix)
        if critic_path is None:
            critic_path = "./train/{}/sac_critic_{}_{}.pth".format(env_name, env_name, suffix)
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, env_name, seed):
        """Load the weights stored by ``save_model(env_name, suffix=seed)``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        different device can still be restored.
        """
        actor_path = './train/{}/sac_actor_{}_{}.pth'.format(env_name, env_name, seed)
        critic_path = './train/{}/sac_critic_{}_{}.pth'.format(env_name, env_name, seed)
        self.policy.load_state_dict(torch.load(actor_path, map_location=self.device))
        self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
Exemple #14
0
class SAC(object):
    """Soft Actor-Critic with a trainable observation encoder and optional Apex AMP.

    Raw observations go through ``process_obs`` before reaching the policy
    and the critic.  The encoder's parameters are part of both the critic and
    the policy optimizer.  When ``opt_level`` is not ``None`` the models and
    optimizers are registered with Apex ``amp`` and losses are scaled.
    """

    def __init__(self,
                 num_inputs,
                 action_space,
                 args,
                 process_obs=None,
                 opt_level='O1'):
        """Build the networks and optimizers.

        Args:
            num_inputs: State size passed to the network constructors.
            action_space: Action space; ``action_space.shape[0]`` is used as
                the action dimension and the object is handed to the policy.
            args: Hyper-parameter namespace.  Reads ``gamma``, ``tau``,
                ``alpha``, ``cuda``, ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``hidden_size`` and ``lr``.
            process_obs: Module mapping observations to states.  ``None``
                means "no preprocessing" and is replaced by an identity
                module.
            opt_level: Apex AMP optimisation level, or ``None`` to disable
                AMP entirely.
        """
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self.dtype = torch.float

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        # The default of None used to crash on ``.to``; fall back to identity.
        if process_obs is None:
            process_obs = torch.nn.Identity()
        self.process_obs = process_obs.to(self.device).to(self.dtype)
        self.critic = QNetwork(num_inputs, action_space.shape[0],
                               args.hidden_size).to(device=self.device).to(
                                   self.dtype)
        self.critic_optim = Adam(list(self.critic.parameters()) +
                                 list(process_obs.parameters()),
                                 lr=args.lr)

        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      args.hidden_size).to(self.device).to(
                                          self.dtype)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2).
            # Truthiness check, matching the one used in update_parameters().
            if self.automatic_entropy_tuning:
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device,
                                             dtype=self.dtype)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_space.shape[0],
                                         args.hidden_size, action_space).to(
                                             self.device).to(self.dtype)
            self.policy_optim = Adam(list(self.policy.parameters()) +
                                     list(process_obs.parameters()),
                                     lr=args.lr)

        else:
            # Deterministic policy: no entropy term at all.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(
                num_inputs, action_space.shape[0], args.hidden_size,
                action_space).to(self.device).to(self.dtype)
            self.policy_optim = Adam(list(self.policy.parameters()) +
                                     list(process_obs.parameters()),
                                     lr=args.lr)

        # Remember whether AMP is active so update_parameters() can fall back
        # to a plain backward pass when it is not.
        self.use_amp = opt_level is not None
        if self.use_amp:
            # Keep the objects amp hands back instead of discarding them.
            models, optimizers = amp.initialize([
                self.policy, self.process_obs, self.critic, self.critic_target
            ], [self.policy_optim, self.critic_optim],
                                                opt_level=opt_level)
            (self.policy, self.process_obs, self.critic,
             self.critic_target) = models
            self.policy_optim, self.critic_optim = optimizers

    def _backward(self, loss, optimizer):
        """Back-propagate ``loss``, through amp loss scaling when AMP is on."""
        if self.use_amp:
            # A separate name keeps the caller's unscaled loss intact.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def select_action(self, obs, evaluate=False):
        """Encode one observation and return an action as a NumPy array.

        With ``evaluate`` left at ``False`` the first value returned by
        ``policy.sample`` is used, otherwise the third one.
        """
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0).to(
                self.dtype)
            state = self.process_obs(obs)
            if evaluate is False:
                action, _, _ = self.policy.sample(state)
            else:
                _, _, action = self.policy.sample(state)
            action = action.detach().cpu().numpy()[0]
        return action

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic / policy / temperature update.

        Args:
            memory: Replay buffer exposing ``sample(batch_size=...)`` that
                returns ``(obs, action, reward, next_obs, mask)``.
            batch_size: Number of transitions to sample.
            updates: Update counter; the target critic is soft-updated when
                it is a multiple of ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as
            floats.  All losses are the unscaled values.

        Raises:
            AssertionError: If the critic or policy loss is not finite.
        """
        # Sample a batch from memory
        obs_batch, action_batch, reward_batch, next_obs_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        obs_batch = torch.FloatTensor(obs_batch).to(self.device).to(self.dtype)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device).to(
            self.dtype)
        action_batch = torch.FloatTensor(action_batch).to(self.device).to(
            self.dtype)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1).to(self.dtype)
        mask_batch = torch.FloatTensor(mask_batch).to(
            self.device).unsqueeze(1).to(self.dtype)

        state_batch = self.process_obs(obs_batch)
        # Bootstrapped soft target; no gradient flows through it.
        with torch.no_grad():
            next_state_batch = self.process_obs(next_obs_batch)
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in policy improvement.
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        assert torch.isfinite(qf_loss).all()
        self._backward(qf_loss, self.critic_optim)
        self.critic_optim.step()

        # Re-encode: the critic step has just changed the encoder weights.
        state_batch = self.process_obs(obs_batch)
        pi, log_pi, _ = self.policy.sample(state_batch)

        # The critic sees a detached state, so the policy loss reaches the
        # encoder only through the policy itself.
        qf1_pi, qf2_pi = self.critic(state_batch.detach(), pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # Standard SAC policy objective: alpha * log pi - Q.
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        assert torch.isfinite(policy_loss).all()
        self._backward(policy_loss, self.policy_optim)
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device).to(self.dtype)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   actor_path=None,
                   critic_path=None,
                   process_obs_path=None):
        """Save policy, critic and encoder weights to the given paths.

        A path left at ``None`` is skipped, mirroring :meth:`load_model`.
        """
        logger.debug(
            f'saving models to {actor_path} and {critic_path} and {process_obs_path}'
        )
        if actor_path is not None:
            torch.save(self.policy.state_dict(), actor_path)
        if critic_path is not None:
            torch.save(self.critic.state_dict(), critic_path)
        if process_obs_path is not None:
            torch.save(self.process_obs.state_dict(), process_obs_path)

    # Load model parameters
    def load_model(self,
                   actor_path=None,
                   critic_path=None,
                   process_obs_path=None):
        """Load policy, critic and encoder weights from the given paths.

        A path left at ``None`` is skipped.  Tensors are mapped onto
        ``self.device`` so a checkpoint written on a different device can
        still be restored.
        """
        logger.info(
            f'Loading models from {actor_path} and {critic_path} and {process_obs_path}'
        )
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=self.device))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=self.device))
        if process_obs_path is not None:
            self.process_obs.load_state_dict(
                torch.load(process_obs_path, map_location=self.device))
Exemple #15
0
class SAC(object):
    """Soft Actor-Critic agent with a twin-Q critic and a target critic.

    The policy is a ``GaussianPolicy`` when ``args.policy`` is ``"Gaussian"``
    and a ``DeterministicPolicy`` otherwise.  In the deterministic case the
    entropy weight is forced to 0 and automatic tuning is disabled.
    """

    def __init__(self, num_inputs, action_space, args):
        """Build the networks and optimizers.

        Args:
            num_inputs: Observation size passed to the network constructors.
            action_space: Action space; ``action_space.shape[0]`` is used as
                the action dimension and the object is handed to the policy.
            args: Hyper-parameter namespace.  Reads ``gamma``, ``tau``,
                ``alpha``, ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``cuda``, ``hidden_size`` and
                ``lr``.
        """
        self.gamma = args.gamma  # discount factor
        self.tau = args.tau  # soft-update coefficient for the target critic
        self.alpha = args.alpha  # entropy temperature

        self.policy_type = args.policy  # "Gaussian" (stochastic) or deterministic
        self.target_update_interval = args.target_update_interval  # updates between target syncs
        self.automatic_entropy_tuning = args.automatic_entropy_tuning  # learn alpha automatically

        self.device = torch.device("cuda" if args.cuda else "cpu")

        # Critic (Q) network
        self.critic = QNetwork(num_inputs,
                               action_space.shape[0], args.hidden_size).to(
                                   device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

        # Target Q network, initialised as a copy of the critic
        self.critic_target = QNetwork(num_inputs, action_space.shape[0],
                                      args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2).
            # Truthiness check, matching the one used in update_parameters().
            if self.automatic_entropy_tuning:
                # torch.prod returns the product of all elements.
                self.target_entropy = -torch.prod(
                    torch.Tensor(action_space.shape).to(
                        self.device)).item()
                self.log_alpha = torch.zeros(1,
                                             requires_grad=True,
                                             device=self.device)
                self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

            self.policy = GaussianPolicy(num_inputs, action_space.shape[0],
                                         args.hidden_size,
                                         action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        else:
            # Deterministic policy: no entropy term at all.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = DeterministicPolicy(num_inputs,
                                              action_space.shape[0],
                                              args.hidden_size,
                                              action_space).to(self.device)
            self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

    def select_action(self, state, eval=False):
        """Return an action for a single state as a NumPy array.

        With ``eval`` equal to ``False`` the first value returned by
        ``policy.sample`` is used, otherwise the third one.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        # ``==`` is kept on purpose so existing callers see unchanged behavior.
        if eval == False:
            action, _, _ = self.policy.sample(state)
        else:
            _, _, action = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Run one critic / policy / temperature update.

        Args:
            memory: Replay buffer exposing ``sample(batch_size=...)`` that
                returns ``(state, action, reward, next_state, mask)``.
            batch_size: Number of transitions to sample.
            updates: Update counter; the target critic is soft-updated when
                it is a multiple of ``target_update_interval``.

        Returns:
            ``(qf1_loss, qf2_loss, policy_loss, alpha_loss, alpha)`` as floats.
        """
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(
            self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            # Action taken in the next state and its log-probability.
            next_state_action, next_state_log_pi, _ = self.policy.sample(
                next_state_batch)
            # Target values of Q1 and Q2 come from the target network.
            qf1_next_target, qf2_next_target = self.critic_target(
                next_state_batch, next_state_action)
            min_qf_next_target = torch.min(
                qf1_next_target,
                qf2_next_target) - self.alpha * next_state_log_pi
            # r + gamma * mask * soft value of the next state
            next_q_value = reward_batch + mask_batch * self.gamma * (
                min_qf_next_target)

        # Two Q-functions to mitigate positive bias in policy improvement.
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        # A single critic step on the summed loss.  Stepping between two
        # backward passes modifies weights that the second pass still needs.
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # The policy loss is built after the critic step so its graph refers
        # to the current critic weights.
        pi, log_pi, _ = self.policy.sample(state_batch)  # action and its log-probability

        qf1_pi, qf2_pi = self.critic(state_batch, pi)  # Q-values of the sampled action
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # Standard SAC policy objective: alpha * log pi - Q.
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            # Temperature objective: -log_alpha * (log pi + target entropy)
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()
                           ).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        # Sync the target network every ``target_update_interval`` updates.
        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
        ), alpha_loss.item(), alpha_tlogs.item()

    # Save model parameters
    def save_model(self,
                   env_name,
                   suffix="",
                   actor_path=None,
                   critic_path=None):
        """Save policy and critic weights, by default under ``models/``.

        The ``models/`` directory is created if missing.  Explicit
        ``actor_path`` / ``critic_path`` values override the default names.
        """
        os.makedirs('models/', exist_ok=True)

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path, critic_path):
        """Load policy and critic weights; a ``None`` path is skipped.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        different device can still be restored.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(
                torch.load(actor_path, map_location=self.device))
        if critic_path is not None:
            self.critic.load_state_dict(
                torch.load(critic_path, map_location=self.device))