Example #1
0
    def __init__(
            self,
            env,
            state_dim,
            action_dim,
            max_action,
            batch_size=256,
            discount=0.99,
            tau=0.005,
            max_timesteps=3e6,
            n_repeat=4,
            actor_lr=3e-4,
            alpha_start=0.7,
            alpha_end=0.9,
            device=torch.device('cuda'),
            critic_lr=3e-4,
            log_freq=200,
    ):
        """Build the actor, twin critic, their optimizers and the CEM searcher.

        Args:
            env: Accepted for interface compatibility; not used here.
            state_dim: Observation size passed to ``Actor`` and ``Critic``.
            action_dim: Action size; stored on the instance as a float.
            max_action: Action bound; the searcher clip is ``0.5 * max_action``.
            batch_size: Batch size handed to the ``Searcher``.
            discount: Reward discount factor.
            tau: Stored as ``self.tau``; not used in this method.
            max_timesteps: Training horizon, stored for schedules.
            n_repeat: Stored as ``self.max_iter_steps``.
            actor_lr: Stored as ``self.actor_lr`` and used for the actor Adam
                optimizer. The original author notes it is not the effective
                actor learning rate.
            alpha_start: Stored as ``self.third_loss_bound``.
            alpha_end: Stored as ``self.third_loss_bound_end``.
            device: Torch device for the networks and the searcher.
            critic_lr: Critic Adam learning rate (previously hard-coded 3e-4).
            log_freq: Stored as ``self.log_freq`` (previously hard-coded 200).
        """
        self.state_dim = state_dim
        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.total_it = 0

        self.device = device
        self.actor_lr = actor_lr  # not the real actor learning rate (author's note)
        self.critic_lr = critic_lr

        # Construction order (actor, critic, searcher) is kept unchanged.
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.actor_lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr)

        cem_sigma = 1e-2
        cem_clip = 0.5 * max_action
        self.searcher = Searcher(action_dim,
                                 max_action,
                                 device=device,
                                 sigma_init=cem_sigma,
                                 clip=cem_clip,
                                 batch_size=batch_size)

        # The networks received the original value; the instance keeps a float.
        self.action_dim = float(action_dim)
        self.log_freq = log_freq
        self.third_loss_bound = alpha_start
        self.third_loss_bound_end = alpha_end
        self.max_timesteps = max_timesteps

        self.max_iter_steps = n_repeat
        self.cem_loss_coef = 1.0 / self.action_dim
        self.selection_action_coef = 1.0
Example #2
0
    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            batch_size=256,
            discount=0.99,
            max_timesteps=3e6,
            actor_lr=3e-4,
            critic_lr=3e-4,
            loss_decay=0.95,
            log_freq=200,
            actor_lr_ratio=1.0,
            device=torch.device('cuda'),
    ):
        """Create the policy/critic networks, their Adam optimizers and the searcher.

        Args:
            state_dim: Observation size given to ``Actor`` and ``Critic``.
            action_dim: Action size; the instance keeps it as a float.
            max_action: Action bound; the initial searcher clip is half of it.
            batch_size: Batch size forwarded to ``Searcher``.
            discount: Reward discount factor.
            max_timesteps: Training horizon kept for schedules.
            actor_lr: Base actor learning rate (author's note: not the real
                actor learning rate).
            critic_lr: Critic Adam learning rate.
            loss_decay: Stored as ``self.loss_decay``.
            log_freq: Stored as ``self.log_freq``.
            actor_lr_ratio: Stored as ``self.actor_lr_ratio``.
            device: Torch device for networks and searcher.
        """
        # Bookkeeping and scalar hyper-parameters.
        self.state_dim = state_dim
        self.max_action = max_action
        self.discount = discount
        self.max_timesteps = max_timesteps
        self.log_freq = log_freq
        self.total_it = 0
        self.device = device

        # Learning-rate related settings.
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.loss_decay = loss_decay
        self.actor_lr_ratio = actor_lr_ratio

        # Networks: actor first, then critic, then searcher (order preserved).
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.actor_lr)
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=self.critic_lr)

        initial_clip = 0.5 * max_action
        self.cem_clip_init = initial_clip
        self.cem_clip = initial_clip
        self.searcher = Searcher(action_dim, max_action, device=device,
                                 sigma_init=1e-2, clip=initial_clip,
                                 batch_size=batch_size)

        self.action_dim = float(action_dim)
        self.cem_loss_coef = 1.0 / self.action_dim
        self.selection_action_coef = 1.0
Example #3
0
class GRAC():
    """Actor-critic agent whose critic targets and actor updates use a CEM ``Searcher``.

    ``__init__`` builds an ``Actor`` (called directly for actions and through
    ``forward_all`` for a distribution), a twin-head ``Critic`` (``Q1``/``Q2``,
    and ``critic(s, a)`` returning both) and a ``Searcher`` that proposes
    alternative actions for a given Q head.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            batch_size=256,
            discount=0.99,
            max_timesteps=3e6,
            actor_lr=3e-4,
            critic_lr=3e-4,
            loss_decay=0.95,
            log_freq=200,
            actor_lr_ratio=1.0,
            device=torch.device('cuda'),
    ):
        """Build networks, optimizers and the CEM searcher.

        Args:
            state_dim: Observation size passed to ``Actor``/``Critic``.
            action_dim: Action size; stored on the instance as a float.
            max_action: Action bound; the searcher clip starts at ``0.5 * max_action``.
            batch_size: Batch size given to the ``Searcher``.
            discount: Reward discount factor.
            max_timesteps: Horizon used by the exploration and clip schedules.
            actor_lr: Base actor learning rate; ``train`` temporarily rescales it.
            critic_lr: Critic Adam learning rate.
            loss_decay: Required relative decrease of the TD loss in the
                critic refinement loop.
            log_freq: Log every ``log_freq`` training iterations.
            actor_lr_ratio: Extra multiplier on the rescaled actor learning rate.
            device: Torch device for the networks.
        """
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.max_action = max_action
        self.discount = discount
        self.total_it = 0

        self.device = device
        self.actor_lr = actor_lr  # here is actor lr is not the real actor learning rate
        self.critic_lr = critic_lr
        self.loss_decay = loss_decay
        self.actor_lr_ratio = actor_lr_ratio

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.actor_lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr)

        cem_sigma = 1e-2
        cem_clip = 0.5 * max_action
        self.cem_clip_init = cem_clip
        self.cem_clip = cem_clip
        self.searcher = Searcher(action_dim,
                                 max_action,
                                 device=device,
                                 sigma_init=cem_sigma,
                                 clip=cem_clip,
                                 batch_size=batch_size)
        self.action_dim = float(action_dim)
        self.log_freq = log_freq
        self.max_timesteps = max_timesteps
        self.cem_loss_coef = 1.0 / float(self.action_dim)
        self.selection_action_coef = 1.0

    def select_action(self, state, writer=None, test=False):
        """Return an action for one unbatched ``state`` as a flat numpy array.

        When ``test`` is exactly ``False``, the actor's action is used. With
        probability ``coef`` it is replaced by a searched action, unless that
        action has a lower ``min(Q1, Q2)``. ``coef`` starts at
        ``selection_action_coef`` and decreases linearly (slope
        ``10 / max_timesteps``) to a floor of 0.05.

        Otherwise, the third value of ``actor.forward_all`` is returned
        (presumably the mean action — confirm against ``Actor``). Everything
        runs under ``torch.no_grad()``. ``writer`` is unused.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            if test is not False:
                _, _, action, _ = self.actor.forward_all(state)
                return action.cpu().data.numpy().flatten()

            action = self.actor(state)
            coef = self.selection_action_coef - min(
                self.selection_action_coef - 0.05,
                float(self.total_it) * 10.0 / float(self.max_timesteps))
            if np.random.uniform(0, 1) < coef:
                better_action, _ = self.searcher.search(state,
                                                        action,
                                                        self.critic.Q2,
                                                        batch_size=1,
                                                        clip=self.cem_clip)
                Q1, Q2 = self.critic(state, action)
                Q = torch.min(Q1, Q2)
                better_Q1, better_Q2 = self.critic(state, better_action)
                better_Q = torch.min(better_Q1, better_Q2)
                # Keep the actor's action where it scores higher than the searched one.
                action_index = (Q > better_Q).squeeze()
                better_action[action_index] = action[action_index]
            else:
                better_action = action
            return better_action.cpu().data.numpy().flatten()

    def lr_scheduler(self, optimizer, lr):
        """Set the learning rate of every parameter group; return ``optimizer``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer

    def update_critic(self, critic_loss):
        """Zero critic gradients, back-propagate ``critic_loss`` and step the critic optimizer."""
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def save(self, filename):
        """Save critic/actor weights and optimizer states to ``filename`` + suffix files."""
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(),
                   filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(),
                   filename + "_actor_optimizer")

    def load(self, filename):
        """Load the four files written by ``save(filename)``."""
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer"))
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer"))

    def _q_bounds(self, reward_max, episode_step_max, reward_min,
                  episode_step_min):
        """Return ``(Q_max, Q_min)``, discounted-sum bounds used to clip Q targets.

        The lower bound uses ``episode_step_min`` steps for a non-negative
        ``reward_min`` and ``episode_step_max`` steps otherwise.
        """
        Q_max = reward_max / (1 - self.discount) * (
            1 - self.discount**int(episode_step_max))
        min_steps = episode_step_min if reward_min >= 0 else episode_step_max
        Q_min = reward_min / (1 - self.discount) * (
            1 - self.discount**int(min_steps))
        return Q_max, Q_min

    @staticmethod
    def _bound_(values, low, high):
        """Clip ``values`` in place: first cap at ``high``, then raise to ``low``."""
        # Boolean masks (not torch.clamp) keep the original ordering semantics
        # when low > high.
        values[values > high] = high
        values[values < low] = low

    def _compute_target(self, next_state, reward, not_done, reward_max,
                        episode_step_max, reward_min, episode_step_min):
        """Compute the bounded bootstrapped critic target without gradients.

        Returns:
            ``(next_action, target_Q1, target_Q2, better_target_Q,
            target_Q_final, Q_max, Q_min)``. ``next_action`` takes, per sample,
            the searched action unless the actor's action has a larger
            ``min(Q1, Q2)``. ``target_Q1``/``target_Q2`` are the critic values
            at that action.
        """
        with torch.no_grad():
            next_action = (self.actor(next_state)).clamp(
                -self.max_action, self.max_action)
            better_next_action, _ = self.searcher.search(next_state,
                                                         next_action,
                                                         self.critic.Q2,
                                                         clip=self.cem_clip)

            target_Q1, target_Q2 = self.critic(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            searched_Q1, searched_Q2 = self.critic(next_state,
                                                   better_next_action)
            searched_Q = torch.min(searched_Q1, searched_Q2)

            # Fall back to the actor's action where the searched one is worse.
            action_index = (target_Q > searched_Q).squeeze()
            better_next_action[action_index] = next_action[action_index]
            target_Q1, target_Q2 = self.critic(next_state, better_next_action)
            better_target_Q = torch.max(searched_Q, target_Q)

            Q_max, Q_min = self._q_bounds(reward_max, episode_step_max,
                                          reward_min, episode_step_min)
            for values in (target_Q1, target_Q2, better_target_Q):
                self._bound_(values, Q_min, Q_max)

            target_Q_final = reward + not_done * self.discount * better_target_Q
            self._bound_(target_Q_final, Q_min, Q_max)
        return (better_next_action, target_Q1, target_Q2, better_target_Q,
                target_Q_final, Q_max, Q_min)

    def _update_actor(self, state, batch_size, reward_range, ratio_it,
                      weights_actor_lr):
        """Take one actor step on ``-(cem_loss_coef * cem_loss + Q1)``.

        ``cem_loss`` is the log-likelihood of the searched action weighted by
        its non-negative advantage over the actor's action. That weight is
        capped at ``reward_range * ratio_it``.

        For this step only, the learning rate is
        ``actor_lr / (1 + weights_actor_lr) * actor_lr_ratio``. It is reset to
        ``actor_lr`` afterwards.
        """
        lr_tmp = self.actor_lr / (float(weights_actor_lr) +
                                  1.0) * self.actor_lr_ratio
        self.actor_optimizer = self.lr_scheduler(self.actor_optimizer, lr_tmp)

        actor_action, _, action_mean, action_sigma = self.actor.forward_all(
            state)
        q_actor_action = self.critic.Q1(state, actor_action)
        m = Normal(action_mean, action_sigma)

        better_action, _ = self.searcher.search(state,
                                                actor_action,
                                                self.critic.Q1,
                                                batch_size=batch_size,
                                                clip=self.cem_clip)
        q_better_action = self.critic.Q1(state, better_action)
        log_prob_better_action = m.log_prob(better_action).sum(1, keepdim=True)

        adv = (q_better_action - q_actor_action).detach()
        adv = torch.max(adv, torch.zeros_like(adv))
        cem_loss = log_prob_better_action * torch.min(
            reward_range * torch.ones_like(adv) * ratio_it, adv)
        actor_loss = -(cem_loss * self.cem_loss_coef + q_actor_action).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.actor_optimizer = self.lr_scheduler(self.actor_optimizer,
                                                 self.actor_lr)

    def _log_q_statistics(self, writer, target_Q1, target_Q2, current_Q1,
                          current_Q2):
        """Write summary statistics of target/current Q estimates to ``writer``."""
        step = self.total_it
        target_Q1_Q2_diff = target_Q1 - target_Q2
        writer.add_scalar('q_diff/target_Q1_Q2_diff_max',
                          target_Q1_Q2_diff.max(), step)
        writer.add_scalar('q_diff/target_Q1_Q2_diff_min',
                          target_Q1_Q2_diff.min(), step)
        writer.add_scalar('q_diff/target_Q1_Q2_diff_mean',
                          target_Q1_Q2_diff.mean(), step)
        writer.add_scalar('q_diff/target_Q1_Q2_diff_abs_mean',
                          target_Q1_Q2_diff.abs().mean(), step)

        current_Q1_Q2_diff = current_Q1 - current_Q2
        writer.add_scalar('q_diff/current_Q1_Q2_diff_max',
                          current_Q1_Q2_diff.max(), step)
        writer.add_scalar('q_diff/current_Q1_Q2_diff_min',
                          current_Q1_Q2_diff.min(), step)
        writer.add_scalar('q_diff/current_Q1_Q2_diff_mean',
                          current_Q1_Q2_diff.mean(), step)
        writer.add_scalar('q_diff/current_Q1_Q2_diff_abs_mean',
                          current_Q1_Q2_diff.abs().mean(), step)

        writer.add_scalar('train_critic/target_Q1/mean',
                          torch.mean(target_Q1), step)
        writer.add_scalar('train_critic/target_Q1/max', target_Q1.max(), step)
        writer.add_scalar('train_critic/target_Q1/min', target_Q1.min(), step)
        writer.add_scalar('train_critic/target_Q1/std',
                          torch.std(target_Q1), step)
        writer.add_scalar('train_critic/target_Q2/mean',
                          torch.mean(target_Q2), step)

        writer.add_scalar('train_critic/current_Q1/mean', current_Q1.mean(),
                          step)
        writer.add_scalar('train_critic/current_Q1/std',
                          torch.std(current_Q1), step)
        writer.add_scalar('train_critic/current_Q1/max', current_Q1.max(),
                          step)
        writer.add_scalar('train_critic/current_Q1/min', current_Q1.min(),
                          step)
        writer.add_scalar('train_critic/current_Q2/mean', current_Q2.mean(),
                          step)

    def train(self,
              replay_buffer,
              batch_size=100,
              writer=None,
              reward_range=20.0,
              reward_max=0,
              episode_step_max=100,
              reward_min=0,
              episode_step_min=1):
        """Run one training iteration: critic update plus refinement, then one actor step.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)``.
            batch_size: Number of transitions to sample.
            writer: Optional object with ``add_scalar(tag, value, step)``.
                When ``None``, nothing is logged.
            reward_range: Cap (times the annealing factor) on the advantage
                weight in the actor loss.
            reward_max, episode_step_max: Define the upper Q bound.
            reward_min, episode_step_min: Define the lower Q bound.
        """
        self.total_it += 1
        # Only log when a writer is supplied; the default None used to crash here.
        log_it = writer is not None and self.total_it % self.log_freq == 0
        # Annealing factor: decreases linearly from 1, floored at 0.1.
        ratio_it = max(1.0 - self.total_it / float(self.max_timesteps), 0.1)
        self.cem_clip = self.cem_clip_init * ratio_it

        state, action, next_state, reward, not_done = replay_buffer.sample(
            batch_size)
        (next_action, target_Q1, target_Q2, better_target_Q, target_Q_final,
         Q_max, Q_min) = self._compute_target(next_state, reward, not_done,
                                              reward_max, episode_step_max,
                                              reward_min, episode_step_min)

        # Step 1: plain TD update towards the bounded target.
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q_final) + F.mse_loss(
            current_Q2, target_Q_final)
        self.update_critic(critic_loss)

        # Step 2: TD loss plus a term keeping Q(next_state, next_action) close
        # to the target values computed before this iteration's updates.
        current_Q1_, current_Q2_ = self.critic(state, action)
        target_Q1_, target_Q2_ = self.critic(next_state, next_action)
        critic_loss2_1 = F.mse_loss(current_Q1_, target_Q_final) + F.mse_loss(
            current_Q2_, target_Q_final)
        critic_loss2_2 = F.mse_loss(target_Q1_, target_Q1) + F.mse_loss(
            target_Q2_, target_Q2)
        # Plain-float weight balancing the two terms by their square roots.
        weight_loss = (math.sqrt(critic_loss2_1.item()) + 1.) / (
            math.sqrt(critic_loss2_2.item()) + 1.)
        critic_loss3 = critic_loss2_1 + critic_loss2_2 * weight_loss
        self.update_critic(critic_loss3)

        # Step 3: repeat step 2 until the TD part shrank by loss_decay relative
        # to both earlier losses and the regularizer is small enough, or at
        # most 51 times.
        idi = 0
        while True:
            idi = idi + 1
            current_Q1_, current_Q2_ = self.critic(state, action)
            target_Q1_, target_Q2_ = self.critic(next_state, next_action)
            critic_loss3_1 = F.mse_loss(
                current_Q1_, target_Q_final) + F.mse_loss(
                    current_Q2_, target_Q_final)
            critic_loss3_2 = F.mse_loss(target_Q1_, target_Q1) + F.mse_loss(
                target_Q2_, target_Q2)
            critic_loss3 = critic_loss3_1 + critic_loss3_2 * weight_loss
            self.update_critic(critic_loss3)
            converged = (
                critic_loss3_1 < critic_loss * self.loss_decay
                and critic_loss3_1 < critic_loss2_1 * self.loss_decay
                and torch.sqrt(critic_loss3_2) < torch.max(
                    torch.mean(torch.abs(better_target_Q)) * 0.01,
                    torch.mean(torch.abs(reward)))
                and torch.sqrt(critic_loss3_2) < reward_max * 1.0)
            if converged or idi > 50:
                break

        # NOTE(review): this is the TD loss from *before* the refinement steps
        # (identical to step 1's loss); confirm it is the intended actor-lr weight.
        weights_actor_lr = critic_loss.detach()

        self._update_actor(state, batch_size, reward_range, ratio_it,
                           weights_actor_lr)

        if log_it:
            with torch.no_grad():
                step = self.total_it
                writer.add_scalar('train_critic/ratio_it', ratio_it, step)
                writer.add_scalar('train_critic/weight_loss', weight_loss,
                                  step)
                writer.add_scalar('train_critic/third_loss_num', idi, step)
                writer.add_scalar('train_critic/Q_max', Q_max, step)
                writer.add_scalar('train_critic/episode_step_max',
                                  episode_step_max, step)
                writer.add_scalar('train_critic/Q_min', Q_min, step)
                writer.add_scalar('train_critic/cem_clip', self.cem_clip, step)
                writer.add_scalar('train_critic/episode_step_min',
                                  episode_step_min, step)
                writer.add_scalar('train_critic/critic_loss', critic_loss,
                                  step)
                writer.add_scalar('train_critic/critic_loss3', critic_loss3,
                                  step)

                writer.add_scalar('train_loss/loss2_1', critic_loss2_1, step)
                writer.add_scalar('train_loss/loss2_2', critic_loss2_2, step)
                writer.add_scalar('train_loss/loss3_1_r',
                                  critic_loss3_1 / critic_loss2_1, step)
                writer.add_scalar('train_loss/loss3_2_r',
                                  critic_loss3_2 / critic_loss2_2, step)
                writer.add_scalar('train_loss/loss3_1_r_loss',
                                  critic_loss3_1 / critic_loss, step)
                writer.add_scalar('train_loss/sqrt_critic_loss3_2',
                                  torch.sqrt(critic_loss3_2), step)
                writer.add_scalar('train_loss/targetQ_condition',
                                  torch.mean(torch.abs(better_target_Q)) * 0.01,
                                  step)
                writer.add_scalar('train_loss/reward_condition',
                                  torch.mean(torch.abs(reward)), step)
                writer.add_scalar('train_loss/max_reward', reward_max, step)

                self._log_q_statistics(writer, target_Q1, target_Q2,
                                       current_Q1, current_Q2)
Example #4
0
class GRAC():
    """Actor-critic agent whose critic targets and actor updates use a CEM ``Searcher``.

    ``__init__`` builds an ``Actor`` (called directly for actions and through
    ``forward_all`` for a distribution), a twin-head ``Critic`` (``Q1``/``Q2``,
    and ``critic(s, a)`` returning both) and a ``Searcher`` whose ``search``
    returns an alternative action tensor for a given Q head.
    """

    def __init__(
            self,
            env,
            state_dim,
            action_dim,
            max_action,
            batch_size=256,
            discount=0.99,
            tau=0.005,
            max_timesteps=3e6,
            n_repeat=4,
            actor_lr=3e-4,
            alpha_start=0.7,
            alpha_end=0.9,
            device=torch.device('cuda'),
            critic_lr=3e-4,
            log_freq=200,
    ):
        """Build networks, optimizers and the CEM searcher.

        Args:
            env: Accepted for interface compatibility; not used.
            state_dim: Observation size passed to ``Actor``/``Critic``.
            action_dim: Action size; stored on the instance as a float.
            max_action: Action bound; the searcher clip is ``0.5 * max_action``.
            batch_size: Batch size handed to the ``Searcher``.
            discount: Reward discount factor.
            tau: Stored as ``self.tau``; not used by this class.
            max_timesteps: Horizon used by the exploration/bound schedules.
            n_repeat: Maximum number of refinement steps in ``train``.
            actor_lr: Base actor learning rate; ``train`` temporarily rescales it.
            alpha_start: Initial relative loss-reduction bound for refinement.
            alpha_end: Final bound, reached at ``max_timesteps``.
            device: Torch device for the networks.
            critic_lr: Critic Adam learning rate (previously hard-coded 3e-4).
            log_freq: Log every ``log_freq`` iterations (previously hard-coded 200).
        """
        self.state_dim = state_dim
        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.total_it = 0

        self.device = device
        self.actor_lr = actor_lr  # not the real actor learning rate (author's note)
        self.critic_lr = critic_lr

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.actor_lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr)

        cem_sigma = 1e-2
        cem_clip = 0.5 * max_action
        self.searcher = Searcher(action_dim,
                                 max_action,
                                 device=device,
                                 sigma_init=cem_sigma,
                                 clip=cem_clip,
                                 batch_size=batch_size)
        self.action_dim = float(action_dim)
        self.log_freq = log_freq
        self.third_loss_bound = alpha_start
        self.third_loss_bound_end = alpha_end
        self.max_timesteps = max_timesteps

        self.max_iter_steps = n_repeat
        self.cem_loss_coef = 1.0 / self.action_dim
        self.selection_action_coef = 1.0

    def select_action(self, state, writer=None, test=False):
        """Return an action for one unbatched ``state`` as a flat numpy array.

        When ``test`` is exactly ``False``, the actor's action is used. With
        probability ``coef`` it is replaced by a searched action, unless that
        action has a lower ``min(Q1, Q2)``. ``coef`` starts at
        ``selection_action_coef`` and decreases linearly (slope
        ``10 / max_timesteps``) to a floor of 0.05.

        Otherwise, the third value of ``actor.forward_all`` is returned
        (presumably the mean action — confirm against ``Actor``). Everything
        runs under ``torch.no_grad()``. ``writer`` is unused.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            if test is not False:
                _, _, action, _ = self.actor.forward_all(state)
                return action.cpu().data.numpy().flatten()

            action = self.actor(state)
            coef = self.selection_action_coef - min(
                self.selection_action_coef - 0.05,
                float(self.total_it) * 10.0 / float(self.max_timesteps))
            if np.random.uniform(0, 1) < coef:
                better_action = self.searcher.search(state,
                                                     action,
                                                     self.critic.Q2,
                                                     batch_size=1)
                Q1, Q2 = self.critic(state, action)
                Q = torch.min(Q1, Q2)
                better_Q1, better_Q2 = self.critic(state, better_action)
                better_Q = torch.min(better_Q1, better_Q2)
                # Keep the actor's action where it scores higher than the searched one.
                action_index = (Q > better_Q).squeeze()
                better_action[action_index] = action[action_index]
            else:
                better_action = action
            return better_action.cpu().data.numpy().flatten()

    def lr_scheduler(self, optimizer, lr):
        """Set the learning rate of every parameter group; return ``optimizer``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer

    def update_critic(self, critic_loss):
        """Zero critic gradients, back-propagate ``critic_loss`` and step the critic optimizer."""
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    def save(self, filename):
        """Save critic/actor weights and optimizer states to ``filename`` + suffix files."""
        # The former ``super().save`` override always raised AttributeError
        # (GRAC has no base class with ``save``) and shadowed this method.
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(),
                   filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(),
                   filename + "_actor_optimizer")

    def load(self, filename):
        """Load the four files written by ``save(filename)``."""
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer"))
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer"))

    def make_Q_contour(self, state, save_folder, base_action):
        """Not available: GRAC has no base class providing this method.

        Raises:
            NotImplementedError: Always. The former ``super()`` delegation
                could only raise AttributeError.
        """
        raise NotImplementedError(
            "make_Q_contour is not implemented for GRAC")

    def _current_third_loss_bound(self):
        """Return the refinement bound, interpolated linearly from start to end over max_timesteps."""
        if self.total_it < self.max_timesteps:
            return self.third_loss_bound + float(self.total_it) / float(
                self.max_timesteps) * (self.third_loss_bound_end -
                                       self.third_loss_bound)
        return self.third_loss_bound_end

    def _compute_target(self, next_state, reward, not_done):
        """Compute the bootstrapped critic target without gradients.

        Returns:
            ``(next_action, target_Q1, target_Q2, target_Q_final)``.
            ``next_action`` takes, per sample, the searched action unless the
            actor's action has a larger ``min(Q1, Q2)``. ``target_Q1``/
            ``target_Q2`` are the critic values at that action.
        """
        with torch.no_grad():
            next_action = (self.actor(next_state)).clamp(
                -self.max_action, self.max_action)
            better_next_action = self.searcher.search(next_state, next_action,
                                                      self.critic.Q2)

            target_Q1, target_Q2 = self.critic(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            searched_Q1, searched_Q2 = self.critic(next_state,
                                                   better_next_action)
            searched_Q = torch.min(searched_Q1, searched_Q2)

            # Fall back to the actor's action where the searched one is worse.
            action_index = (target_Q > searched_Q).squeeze()
            better_next_action[action_index] = next_action[action_index]
            target_Q1, target_Q2 = self.critic(next_state, better_next_action)

            better_target_Q = torch.max(searched_Q, target_Q)
            target_Q_final = reward + not_done * self.discount * better_target_Q
        return better_next_action, target_Q1, target_Q2, target_Q_final

    def _regularized_critic_loss(self, state, action, next_state, next_action,
                                 target_Q_final, target_Q1, target_Q2):
        """TD loss plus a term keeping Q(next_state, next_action) near the pre-update targets."""
        current_Q1_, current_Q2_ = self.critic(state, action)
        target_Q1_, target_Q2_ = self.critic(next_state, next_action)
        return F.mse_loss(current_Q1_, target_Q_final) + F.mse_loss(
            current_Q2_, target_Q_final) + F.mse_loss(
                target_Q1_, target_Q1) + F.mse_loss(target_Q2_, target_Q2)

    def _update_actor(self, state, batch_size, reward_range, weights_actor_lr):
        """Take one actor step on ``-(cem_loss_coef * cem_loss + Q1)``.

        ``cem_loss`` is the log-likelihood of the searched action weighted by
        its non-negative advantage over the actor's action, capped at
        ``reward_range``. For this step only, the learning rate is
        ``actor_lr / (1 + weights_actor_lr)``; it is reset to ``actor_lr``
        afterwards.
        """
        lr_tmp = self.actor_lr / (float(weights_actor_lr) + 1.0)
        self.actor_optimizer = self.lr_scheduler(self.actor_optimizer, lr_tmp)

        actor_action, _, action_mean, action_sigma = self.actor.forward_all(
            state)
        q_actor_action = self.critic.Q1(state, actor_action)
        m = Normal(action_mean, action_sigma)

        better_action = self.searcher.search(state,
                                             actor_action,
                                             self.critic.Q1,
                                             batch_size=batch_size)
        q_better_action = self.critic.Q1(state, better_action)
        log_prob_better_action = m.log_prob(better_action).sum(1, keepdim=True)

        adv = (q_better_action - q_actor_action).detach()
        adv = torch.max(adv, torch.zeros_like(adv))
        cem_loss = log_prob_better_action * torch.min(
            reward_range * torch.ones_like(adv), adv)
        actor_loss = -(cem_loss * self.cem_loss_coef + q_actor_action).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.actor_optimizer = self.lr_scheduler(self.actor_optimizer,
                                                 self.actor_lr)

    def _log_q_statistics(self, writer, target_Q1, target_Q2, current_Q1,
                          current_Q2):
        """Write summary statistics of target/current Q estimates to ``writer``."""
        step = self.total_it
        target_Q1_Q2_diff = target_Q1 - target_Q2
        writer.add_scalar('q_diff/target_Q1_Q2_diff_max',
                          target_Q1_Q2_diff.max(), step)
        writer.add_scalar('q_diff/target_Q1_Q2_diff_min',
                          target_Q1_Q2_diff.min(), step)
        writer.add_scalar('q_diff/target_Q1_Q2_diff_mean',
                          target_Q1_Q2_diff.mean(), step)
        writer.add_scalar('q_diff/target_Q1_Q2_diff_abs_mean',
                          target_Q1_Q2_diff.abs().mean(), step)

        current_Q1_Q2_diff = current_Q1 - current_Q2
        writer.add_scalar('q_diff/current_Q1_Q2_diff_max',
                          current_Q1_Q2_diff.max(), step)
        writer.add_scalar('q_diff/current_Q1_Q2_diff_min',
                          current_Q1_Q2_diff.min(), step)
        writer.add_scalar('q_diff/current_Q1_Q2_diff_mean',
                          current_Q1_Q2_diff.mean(), step)
        writer.add_scalar('q_diff/current_Q1_Q2_diff_abs_mean',
                          current_Q1_Q2_diff.abs().mean(), step)

        writer.add_scalar('train_critic/target_Q1/mean',
                          torch.mean(target_Q1), step)
        writer.add_scalar('train_critic/target_Q1/max', target_Q1.max(), step)
        writer.add_scalar('train_critic/target_Q1/min', target_Q1.min(), step)
        writer.add_scalar('train_critic/target_Q1/std',
                          torch.std(target_Q1), step)
        writer.add_scalar('train_critic/target_Q2/mean',
                          torch.mean(target_Q2), step)

        writer.add_scalar('train_critic/current_Q1/mean', current_Q1.mean(),
                          step)
        writer.add_scalar('train_critic/current_Q1/std',
                          torch.std(current_Q1), step)
        writer.add_scalar('train_critic/current_Q1/max', current_Q1.max(),
                          step)
        writer.add_scalar('train_critic/current_Q1/min', current_Q1.min(),
                          step)
        writer.add_scalar('train_critic/current_Q2/mean', current_Q2.mean(),
                          step)

    def train(self,
              replay_buffer,
              batch_size=100,
              writer=None,
              reward_range=20.0):
        """Run one training iteration: critic update plus refinement, then one actor step.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)``.
            batch_size: Number of transitions to sample.
            writer: Optional object with ``add_scalar(tag, value, step)``.
                When ``None``, nothing is logged.
            reward_range: Cap on the advantage weight in the actor loss.
        """
        self.total_it += 1
        # Only log when a writer is supplied; the default None used to crash here.
        log_it = writer is not None and self.total_it % self.log_freq == 0
        state, action, next_state, reward, not_done = replay_buffer.sample(
            batch_size)
        next_action, target_Q1, target_Q2, target_Q_final = \
            self._compute_target(next_state, reward, not_done)

        # Step 1: plain TD update.
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q_final) + F.mse_loss(
            current_Q2, target_Q_final)
        self.update_critic(critic_loss)

        # Step 2: one regularized update; its loss is the reference for step 3.
        critic_loss3 = self._regularized_critic_loss(state, action, next_state,
                                                     next_action,
                                                     target_Q_final, target_Q1,
                                                     target_Q2)
        self.update_critic(critic_loss3)
        init_critic_loss3 = critic_loss3.detach()

        # Step 3: repeat until the loss falls below bound * reference
        # (cond1 = 1) or max_iter_steps repetitions are done.
        bound = self._current_third_loss_bound()  # invariant within this call
        idi = 0
        cond1 = 0
        while True:
            idi = idi + 1
            critic_loss3 = self._regularized_critic_loss(
                state, action, next_state, next_action, target_Q_final,
                target_Q1, target_Q2)
            self.update_critic(critic_loss3)
            if critic_loss3 < init_critic_loss3 * bound:
                cond1 = 1
                break
            if idi >= self.max_iter_steps:
                break

        # NOTE(review): this is the TD loss from *before* the refinement steps
        # (identical to step 1's loss); confirm it is the intended actor-lr weight.
        weights_actor_lr = critic_loss.detach()

        self._update_actor(state, batch_size, reward_range, weights_actor_lr)

        if log_it:
            with torch.no_grad():
                step = self.total_it
                writer.add_scalar('train_critic/third_loss_cond1', cond1, step)
                writer.add_scalar('train_critic/third_loss_num', idi, step)
                writer.add_scalar('train_critic/critic_loss', critic_loss,
                                  step)
                writer.add_scalar('train_critic/critic_loss3', critic_loss3,
                                  step)
                self._log_q_statistics(writer, target_Q1, target_Q2,
                                       current_Q1, current_Q2)
Example #5
0
class GRAC:
    """Actor-critic agent with search-guided policy updates.

    Each ``train`` call:
      1. builds a clipped, pessimistic bootstrap target from actions the actor
         samples at the next state;
      2. fits the critic with one plain TD step, one weighted TD + anchor step,
         then up to ``_MAX_EXTRA_CRITIC_STEPS`` more weighted steps until the TD
         loss has decayed by ``loss_decay``;
      3. updates the actor on its Q-value plus a log-likelihood term toward an
         action proposed by ``self.searcher``.

    NOTE(review): ``Actor``, ``Critic`` and ``Searcher`` are defined elsewhere;
    the searcher is presumably a CEM-style local search (cf. the ``cem_*``
    names) -- confirm against its implementation.
    """

    # Cap on the extra critic refinement steps per ``train`` call.
    _MAX_EXTRA_CRITIC_STEPS = 51

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        batch_size=256,
        discount=0.99,
        max_timesteps=3e6,
        actor_lr=3e-4,
        critic_lr=3e-4,
        loss_decay=0.95,
        log_freq=200,
        cem_loss_coef=1.0,
        device=torch.device('cuda'),
    ):
        """Build the actor, critic, their Adam optimizers and the action searcher.

        Args:
            state_dim: Observation dimensionality (passed to Actor/Critic).
            action_dim: Action dimensionality; stored as ``float`` in
                ``self.action_dim``.
            max_action: Action bound; the initial search clip is 0.5x this.
            batch_size: Batch size handed to the searcher.
            discount: Discount factor gamma.
            max_timesteps: Horizon over which ``train`` anneals the search clip.
            actor_lr: Adam learning rate of the actor.
            critic_lr: Adam learning rate of the critic.
            loss_decay: Relative drop of the TD loss that ends the extra critic
                refinement steps in ``train``.
            log_freq: Log to the writer every ``log_freq`` training iterations.
            cem_loss_coef: Weight of the search-guided actor term; divided by
                ``action_dim``.
            device: Torch device for the networks and input tensors.
        """
        self.state_dim = state_dim
        self.max_action = max_action
        self.discount = discount
        self.total_it = 0

        self.device = device
        # Used directly as the actor's Adam learning rate.
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.loss_decay = loss_decay

        # Keep construction order (actor, critic, searcher): module init
        # presumably consumes the global RNG, so reordering would change seeds.
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        cem_sigma = 1e-2
        cem_clip = 0.5 * max_action
        # ``cem_clip`` is re-derived from ``cem_clip_init`` on every train step.
        self.cem_clip_init = cem_clip
        self.cem_clip = cem_clip
        self.searcher = Searcher(action_dim, max_action, device=device,
                                 sigma_init=cem_sigma, clip=cem_clip,
                                 batch_size=batch_size)
        self.action_dim = float(action_dim)
        self.log_freq = log_freq
        self.max_timesteps = max_timesteps
        # Normalize the search-guided actor term by the action dimensionality.
        self.cem_loss_coef = cem_loss_coef / self.action_dim

    def select_action(self, state, writer=None, test=False):
        """Return an action for a single (flattened) ``state`` as a 1-D numpy array.

        With ``test is False`` the action comes from ``self.actor(state)``;
        otherwise the third output of ``self.actor.forward_all`` is used.
        ``writer`` is unused. No autograd graph is built in either case.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            if test is False:
                action = self.actor(state)
            else:
                _, _, action, _ = self.actor.forward_all(state)
        return action.cpu().data.numpy().flatten()

    def lr_scheduler(self, optimizer, lr):
        """Set every param group of ``optimizer`` to learning rate ``lr``; return it."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer

    def update_critic(self, critic_loss):
        """Take one critic optimizer step on ``critic_loss``."""
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

    @staticmethod
    def _clip_q_(q, q_min, q_max):
        """Clip ``q`` into [q_min, q_max] in place by masked assignment; return ``q``.

        The upper bound is applied first, then the lower one, so if
        ``q_min > q_max`` every element ends at ``q_min``; NaNs are left as-is.
        """
        q[q > q_max] = q_max
        q[q < q_min] = q_min
        return q

    def _q_bounds(self, reward_max, episode_step_max, reward_min, episode_step_min):
        """Return ``(Q_max, Q_min)``: discounted sums of a constant per-step reward.

        Uses ``r * (1 - gamma**T) / (1 - gamma)``. The lower bound uses the
        shortest episode when ``reward_min >= 0`` and the longest otherwise.
        """
        gamma = self.discount
        q_max = reward_max / (1 - gamma) * (1 - gamma ** int(episode_step_max))
        steps_for_min = episode_step_min if reward_min >= 0 else episode_step_max
        q_min = reward_min / (1 - gamma) * (1 - gamma ** int(steps_for_min))
        return q_max, q_min

    def _critic_losses(self, state, action, next_state, next_action, target_Q_final, target_next_Q):
        """Return ``(TD loss at (s, a), anchor loss at (s', a'))`` for the current critic.

        The anchor term penalizes drift of Q(s', a') away from
        ``target_next_Q``, the bootstrap value computed before any critic update
        in this iteration.
        """
        current = self.critic(state, action)
        nxt = self.critic(next_state, next_action)
        return F.mse_loss(current, target_Q_final), F.mse_loss(nxt, target_next_Q)

    @staticmethod
    def _log_stats(writer, tag, values, step):
        """Log mean/std/max/min of ``values`` under ``train_critic/<tag>/...``."""
        writer.add_scalar(f'train_critic/{tag}/mean', values.mean(), step)
        writer.add_scalar(f'train_critic/{tag}/std', torch.std(values), step)
        writer.add_scalar(f'train_critic/{tag}/max', values.max(), step)
        writer.add_scalar(f'train_critic/{tag}/min', values.min(), step)

    def train(self, replay_buffer, batch_size=100, writer=None, reward_range=20.0, reward_max=0, episode_step_max=100, reward_min=0, episode_step_min=1):
        """Run one training iteration: critic refinement followed by one actor step.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)`` tensors.
            batch_size: Number of transitions to sample.
            writer: Optional writer exposing ``add_scalar(tag, value, step)``
                (presumably a TensorBoard ``SummaryWriter``). Logging happens
                every ``log_freq`` iterations and is skipped when ``None``.
            reward_range: Unused; kept for call-site compatibility.
            reward_max, episode_step_max: Per-step reward bound and episode
                length for the upper Q bound.
            reward_min, episode_step_min: Same for the lower Q bound.
        """
        self.total_it += 1
        # Without a writer there is nothing to log to (previously this crashed).
        log_it = writer is not None and self.total_it % self.log_freq == 0
        # Linear anneal from 1.0 down to a floor of 0.1 over max_timesteps.
        ratio_it = max(1.0 - self.total_it / float(self.max_timesteps), 0.1)
        if log_it:
            writer.add_scalar('train_critic/ratio_it', ratio_it, self.total_it)
        self.cem_clip = self.cem_clip_init * ratio_it

        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
        Q_max, Q_min = self._q_bounds(reward_max, episode_step_max, reward_min, episode_step_min)

        with torch.no_grad():
            next_action1, next_action2, next_action_mean, _, _, _, _ = \
                self.actor.forward_all_sample(next_state)

            target_Q1 = self._clip_q_(self.critic(next_state, next_action1), Q_min, Q_max)
            target_Q2 = self._clip_q_(self.critic(next_state, next_action2), Q_min, Q_max)
            target_mean = self._clip_q_(self.critic(next_state, next_action_mean), Q_min, Q_max)

            # Pessimistic target: min over the sampled and mean actions, floored
            # at ``target_mean - 0.05 * Q_max``.
            target_Q = torch.min(target_Q1, target_mean)
            target_Q = torch.max(target_mean - torch.ones_like(target_mean) * Q_max * 0.05, target_Q)

            # Diagnostics for logging; taken before the final clip of target_Q.
            target_Q_mean_diff = target_mean - target_Q
            target_Q1_diff = target_Q1 - target_Q
            target_Q2_diff = target_Q2 - target_Q

            self._clip_q_(target_Q, Q_min, Q_max)
            target_Q_final = self._clip_q_(
                reward + not_done * self.discount * target_Q, Q_min, Q_max)

            # From here on target_Q1 is the clipped bootstrap value at s', the
            # anchor for Q(s', next_action) during critic refinement.
            target_Q1 = target_Q
            next_action = next_action_mean

        # Step 1: plain TD step.
        current_Q1 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q_final)
        self.update_critic(critic_loss)

        # Step 2: TD + anchor, with the anchor weighted by the ratio of RMS losses
        # (+1 in both terms keeps the weight finite near zero).
        critic_loss2_1, critic_loss2_2 = self._critic_losses(
            state, action, next_state, next_action, target_Q_final, target_Q1)
        weight_loss = (math.sqrt(critic_loss2_1.item()) + 1.) / (math.sqrt(critic_loss2_2.item()) + 1.)
        critic_loss3 = critic_loss2_1 + critic_loss2_2 * weight_loss
        self.update_critic(critic_loss3)

        # Step 3: repeat the weighted step until the TD loss has decayed relative
        # to both earlier TD losses, or the step cap is reached.
        for idi in range(1, self._MAX_EXTRA_CRITIC_STEPS + 1):
            critic_loss3_1, critic_loss3_2 = self._critic_losses(
                state, action, next_state, next_action, target_Q_final, target_Q1)
            critic_loss3 = critic_loss3_1 + critic_loss3_2 * weight_loss
            self.update_critic(critic_loss3)
            if (critic_loss3_1 < critic_loss * self.loss_decay
                    and critic_loss3_1 < critic_loss2_1 * self.loss_decay):
                break

        if log_it:
            writer.add_scalar('train_critic/weight_loss', weight_loss, self.total_it)
            writer.add_scalar('train_critic/third_loss_num', idi, self.total_it)
            writer.add_scalar('train_critic/Q_max', Q_max, self.total_it)
            writer.add_scalar('train_critic/episode_step_max', episode_step_max, self.total_it)
            writer.add_scalar('train_critic/Q_min', Q_min, self.total_it)
            writer.add_scalar('train_critic/cem_clip', self.cem_clip, self.total_it)
            writer.add_scalar('train_critic/episode_step_min', episode_step_min, self.total_it)
            writer.add_scalar('train_loss/loss2_1', critic_loss2_1, self.total_it)
            writer.add_scalar('train_loss/loss2_2', critic_loss2_2, self.total_it)
            writer.add_scalar('train_loss/loss3_1_r', critic_loss3_1 / critic_loss2_1, self.total_it)
            writer.add_scalar('train_loss/loss3_2_r', critic_loss3_2 / critic_loss2_2, self.total_it)
            writer.add_scalar('train_loss/loss3_1_r_loss', critic_loss3_1 / critic_loss, self.total_it)
            writer.add_scalar('train_loss/sqrt_critic_loss3_2', torch.sqrt(critic_loss3_2), self.total_it)
            writer.add_scalar('train_loss/max_reward', reward_max, self.total_it)

        # Actor update: maximize Q(s, pi(s)) plus a likelihood term toward the
        # searcher's better action, weighted by its (capped) advantage.
        actor_action, _, action_mean, action_sigma = self.actor.forward_all(state)
        q_actor_action = self.critic.Q1(state, actor_action)
        m = Normal(action_mean, action_sigma)

        better_action, _ = self.searcher.search(state, actor_action, self.critic.Q1,
                                                batch_size=batch_size, clip=self.cem_clip)
        q_better_action = self.critic.Q1(state, better_action)

        # In-place clipping also zeroes the gradient of q_actor_action at the
        # clipped entries.
        self._clip_q_(q_better_action, Q_min, Q_max)
        self._clip_q_(q_actor_action, Q_min, Q_max)

        log_prob_better_action = m.log_prob(better_action).sum(1, keepdim=True)

        adv = (q_better_action - q_actor_action).detach()
        adv = torch.min(adv, torch.ones_like(adv) * Q_max * 0.05)
        cem_loss = log_prob_better_action * adv
        actor_loss = -(cem_loss * self.cem_loss_coef + q_actor_action).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        if log_it:
            with torch.no_grad():
                writer.add_scalar('train_critic/critic_loss', critic_loss, self.total_it)
                writer.add_scalar('train_critic/critic_loss3', critic_loss3, self.total_it)
                for tag, values in (
                    ('target_Q1', target_Q1),
                    ('current_Q1', current_Q1),
                    ('adv', adv),
                    ('target_Q1_diff', target_Q1_diff),
                    ('target_Q2_diff', target_Q2_diff),
                    ('target_Q_mean_diff', target_Q_mean_diff),
                ):
                    self._log_stats(writer, tag, values, self.total_it)

    def save(self, filename):
        """Save critic/actor weights and optimizer states to ``filename + suffix`` files."""
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Load the four state dicts written by ``save`` from ``filename + suffix``."""
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))