Example #1
    def __init__(self,
                 env_name,
                 env,
                 critic_lr=3e-4,
                 train_iters=20,
                 backtrack_coeff=1,
                 backtrack_damp_coeff=0.5,
                 backtrack_alpha=0.5,
                 delta=0.01,
                 sample_size=2048,
                 gamma=0.99,
                 lam=0.97,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=20):
        self.env_name = env_name
        self.env = env
        self.critic_lr = critic_lr
        self.train_iters = train_iters
        self.backtrack_coeff = backtrack_coeff
        self.backtrack_damp_coeff = backtrack_damp_coeff
        self.backtrack_alpha = backtrack_alpha
        self.sample_size = sample_size
        self.delta = delta
        self.gamma = gamma
        self.lam = lam
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.total_step = 0
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs/TRPO_{}'.format(self.env_name))
        self.loss_fn = F.mse_loss

        n_state, n_action = env.observation_space.shape[
            0], env.action_space.shape[0]
        self.old_policy = GaussianActor(
            n_state, n_action, 128,
            action_scale=int(env.action_space.high[0])).to(self.device)
        self.new_policy = GaussianActor(
            n_state, n_action, 128,
            action_scale=int(env.action_space.high[0])).to(self.device)
        update_model(self.old_policy, self.new_policy)
        self.critic = Critic(n_state, 128).to(self.device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.trace = Trace()

        print(self.new_policy)
        print(self.critic)
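These snippets call an update_model helper that is not shown anywhere in the listing. A minimal sketch of what it is assumed to do, consistent with how it is called (no third argument for a hard copy, a small tau such as 0.05 for the soft target update in DDPG):

def update_model(target, source, tau=1.0):
    # copy source parameters into target; tau=1.0 is a hard copy,
    # tau < 1.0 gives the Polyak (soft) update used for DDPG-style target networks
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * source_param.data +
                                (1.0 - tau) * target_param.data)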
Example #2
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-3,
                 gamma=0.99,
                 batch_size=32,
                 replay_memory_size=1e6,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=10):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.gamma = gamma
        self.batch_size = batch_size
        self.replay_memory_size = replay_memory_size
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs/DDPG_{}'.format(self.env_name))
        self.loss_fn = F.mse_loss
        self.memory = Memory(int(replay_memory_size), batch_size)

        n_state, n_action = env.observation_space.shape[0], env.action_space.shape[0]
        self.noise = OUNoise(n_action)

        self.actor = DeterministicActor(n_state, n_action,
                                        action_scale=int(env.action_space.high[0])).to(self.device)
        self.target_actor = DeterministicActor(n_state, n_action,
                                               action_scale=int(env.action_space.high[0])).to(self.device)
        update_model(self.target_actor, self.actor)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)

        self.critic = Critic(n_state + n_action).to(self.device)
        self.target_critic = Critic(n_state + n_action).to(self.device)
        update_model(self.target_critic, self.critic)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        print(self.actor)
        print(self.critic)
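Example #2 builds an OUNoise object for exploration, which is also not defined in these snippets. A minimal sketch under the usual Ornstein-Uhlenbeck parameterization (the mu, theta and sigma values are conventional defaults, not taken from the original code):

import numpy as np

class OUNoise:
    # temporally correlated exploration noise for deterministic policies
    def __init__(self, n_action, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(n_action)
        self.theta = theta
        self.sigma = sigma
        self.state = self.mu.copy()

    def reset(self):
        self.state = self.mu.copy()

    def __call__(self):
        dx = self.theta * (self.mu - self.state) + \
            self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state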
Example #3
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 sample_size=2048,
                 gamma=0.99,
                 lam=0.95,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=10):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.sample_size = sample_size
        self.gamma = gamma
        self.lam = lam
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.total_step = 0
        self.state_normalize = ZFilter(env.observation_space.shape[0])
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs_epoch_update/A2C_{}'.format(
                self.env_name))
        self.loss_fn = F.smooth_l1_loss

        self.trace = Trace()
        self.actor = GaussianActor(
            env.observation_space.shape[0],
            env.action_space.shape[0],
            action_scale=int(env.action_space.high[0])).to(self.device)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic = Critic(env.observation_space.shape[0]).to(self.device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        print(self.actor)
        print(self.critic)
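Examples #3 and #5 normalize observations with a ZFilter, and the PPO example later stores it with ZFilter.save. A rough sketch of such a running z-score filter, assuming a Welford-style online mean/variance estimate and pickle-based saving (the clip value and the file format are guesses):

import numpy as np
import pickle

class ZFilter:
    # running z-score normalization of observations
    def __init__(self, shape, clip=10.0):
        self.clip = clip
        self.n = 0
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)

    def __call__(self, x, update=True):
        x = np.asarray(x, dtype=np.float64)
        if update:
            self.n += 1
            delta = x - self.mean
            self.mean = self.mean + delta / self.n
            self.var = self.var + (delta * (x - self.mean) - self.var) / self.n
        z = (x - self.mean) / (np.sqrt(self.var) + 1e-8)
        return np.clip(z, -self.clip, self.clip)

    @staticmethod
    def save(zfilter, path):
        with open(path, 'wb') as f:
            pickle.dump(zfilter, f)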
Example #4
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-3,
                 gamma=0.99,
                 is_continue_action_space=False,
                 reward_shaping_func=lambda x: x[1],
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=10):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.gamma = gamma
        self.reward_shaping_func = reward_shaping_func
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs_step_update/A2C_{}'.format(
                self.env_name))
        self.loss_fn = F.mse_loss

        if is_continue_action_space:
            self.actor = GaussianActor(
                env.observation_space.shape[0],
                env.action_space.shape[0],
                action_scale=int(env.action_space.high[0])).to(self.device)
        else:
            self.actor = Actor(env.observation_space.shape[0],
                               env.action_space.n).to(self.device)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic = Critic(env.observation_space.shape[0]).to(self.device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        print(self.actor)
        print(self.critic)
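The reward_shaping_func in Example #4 receives the whole (next_state, reward, done, info) tuple from env.step, and the default lambda x: x[1] simply passes the raw reward through. A hypothetical custom function, just to illustrate the hook:

def shaped_reward(step_result):
    # penalize early termination, otherwise keep the environment reward
    s_, r, done, info = step_result
    return r - 10.0 if done else r

# agent = A2C('CartPole-v0', env, reward_shaping_func=shaped_reward)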
Example #5
class A2C:
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 sample_size=2048,
                 gamma=0.99,
                 lam=0.95,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=10):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.sample_size = sample_size
        self.gamma = gamma
        self.lam = lam
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.total_step = 0
        self.state_normalize = ZFilter(env.observation_space.shape[0])
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs_epoch_update/A2C_{}'.format(
                self.env_name))
        self.loss_fn = F.smooth_l1_loss

        self.trace = Trace()
        self.actor = GaussianActor(
            env.observation_space.shape[0],
            env.action_space.shape[0],
            action_scale=int(env.action_space.high[0])).to(self.device)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic = Critic(env.observation_space.shape[0]).to(self.device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        print(self.actor)
        print(self.critic)

    def select_action(self, state, is_test=False):
        state = torch.tensor(state,
                             dtype=torch.float).unsqueeze(0).to(self.device)
        a, log_prob = self.actor.sample(state, is_test)
        return a, log_prob

    def train(self, epochs):
        best_eval = -1e6
        for epoch in range(epochs):
            num_sample = 0
            self.trace.clear()
            s = self.env.reset()
            s = self.state_normalize(s)
            while num_sample < self.sample_size:
                self.env.render()
                a, log_prob = self.select_action(s)
                torch_s = torch.tensor(s, dtype=torch.float).unsqueeze(0).to(
                    self.device)
                v = self.critic(torch_s)
                s_, r, done, _ = self.env.step(a.cpu().detach().numpy()[0])
                s_ = self.state_normalize(s_)
                self.trace.push(s, a, log_prob, r, s_, not done,
                                v)  # how should this be stored so that learn() does not need to reshape?
                num_sample += 1
                self.total_step += 1
                s = s_
                if done:
                    s = self.env.reset()
                    s = self.state_normalize(s)

            policy_loss, critic_loss = self.learn()

            self.writer.add_scalar('loss/actor_loss', policy_loss,
                                   self.total_step)
            self.writer.add_scalar('loss/critic_loss', critic_loss,
                                   self.total_step)

            if (epoch + 1) % self.save_model_frequency == 0:
                save_model(
                    self.critic,
                    'model_epoch_update/{}_model/critic_{}'.format(
                        self.env_name, self.total_step))
                save_model(
                    self.actor, 'model_epoch_update/{}_model/actor_{}'.format(
                        self.env_name, self.total_step))

            if (epoch + 1) % self.eval_frequency == 0:
                eval_r = self.evaluate()
                print('epoch', epoch, 'evaluate reward', eval_r)
                self.writer.add_scalar('reward', eval_r, epoch)
                if eval_r > best_eval:
                    best_eval = eval_r
                    save_model(
                        self.critic,
                        'model_epoch_update/{}_model/best_critic'.format(
                            self.env_name))
                    save_model(
                        self.actor,
                        'model_epoch_update/{}_model/best_actor'.format(
                            self.env_name))

    def learn(self):
        all_data = self.trace.get()
        all_state = torch.tensor(all_data.state,
                                 dtype=torch.float).to(self.device)
        all_log_prob = torch.cat(all_data.log_prob).to(self.device)

        adv, total_reward = self.trace.cal_advantage(self.gamma, self.lam)
        adv = adv.reshape(len(self.trace), -1).to(self.device)
        total_reward = total_reward.reshape(len(self.trace),
                                            -1).to(self.device)

        all_value = self.critic(all_state)
        critic_loss = self.loss_fn(all_value, total_reward.detach())
        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        policy_loss = (-all_log_prob * adv.detach() +
                       0.01 * all_log_prob.exp() * all_log_prob).mean()

        self.actor_opt.zero_grad()
        policy_loss.backward()
        self.actor_opt.step()

        return policy_loss.item(), critic_loss.item()

    def evaluate(self, epochs=3, is_render=False):
        eval_r = 0
        for _ in range(epochs):
            s = self.env.reset()
            s = self.state_normalize(s, update=False)
            while True:
                if is_render:
                    self.env.render()
                with torch.no_grad():
                    a, _ = self.select_action(s, is_test=True)
                s_, r, done, _ = self.env.step(a.cpu().detach().numpy()[0])
                s_ = self.state_normalize(s_, update=False)
                s = s_
                eval_r += r
                if done:
                    break
        return eval_r / epochs
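Example #5 relies on Trace.cal_advantage(gamma, lam), whose implementation is not shown. It is assumed to compute GAE(lambda) advantages together with the discounted returns used as the critic target; a standalone sketch of that computation over one collected batch (rewards, value estimates and not-done masks as flat lists):

import torch

def cal_advantage(rewards, values, masks, gamma=0.99, lam=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
    # A_t     = delta_t + gamma * lam * mask_t * A_{t+1}
    T = len(rewards)
    adv = torch.zeros(T)
    ret = torch.zeros(T)
    gae, next_value, next_return = 0.0, 0.0, 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_value * masks[t] - values[t]
        gae = delta + gamma * lam * masks[t] * gae
        next_return = rewards[t] + gamma * masks[t] * next_return
        adv[t] = gae
        ret[t] = next_return
        next_value = values[t]
    return adv, ret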
Example #6
class TRPO:
    def __init__(self,
                 env_name,
                 env,
                 critic_lr=3e-4,
                 train_iters=20,
                 backtrack_coeff=1,
                 backtrack_damp_coeff=0.5,
                 backtrack_alpha=0.5,
                 delta=0.01,
                 sample_size=2048,
                 gamma=0.99,
                 lam=0.97,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=20):
        self.env_name = env_name
        self.env = env
        self.critic_lr = critic_lr
        self.train_iters = train_iters
        self.backtrack_coeff = backtrack_coeff
        self.backtrack_damp_coeff = backtrack_damp_coeff
        self.backtrack_alpha = backtrack_alpha
        self.sample_size = sample_size
        self.delta = delta
        self.gamma = gamma
        self.lam = lam
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.total_step = 0
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs/TRPO_{}'.format(self.env_name))
        self.loss_fn = F.mse_loss

        n_state, n_action = env.observation_space.shape[
            0], env.action_space.shape[0]
        self.old_policy = GaussianActor(
            n_state, n_action, 128,
            action_scale=int(env.action_space.high[0])).to(self.device)
        self.new_policy = GaussianActor(
            n_state, n_action, 128,
            action_scale=int(env.action_space.high[0])).to(self.device)
        update_model(self.old_policy, self.new_policy)
        self.critic = Critic(n_state, 128).to(self.device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.trace = Trace()

        print(self.new_policy)
        print(self.critic)

    def select_action(self, state, is_test=False):
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        a, _ = self.new_policy.sample(state, is_test)
        return a.cpu().detach().numpy()

    def train(self, epochs):
        best_eval = -1e6
        for epoch in range(epochs):
            self.trace.clear()
            num_sample = 0
            s = self.env.reset()
            # collect data
            while num_sample < self.sample_size:
                self.env.render()
                a = self.select_action(s)
                s_, r, done, _ = self.env.step(a)
                v = self.critic(
                    torch.tensor(s, dtype=torch.float).to(self.device))
                self.trace.push(s, a, r, s_, done, v)
                num_sample += 1
                self.total_step += 1

                s = s_
                if done:
                    s = self.env.reset()

            self.learn()
            if (epoch + 1) % self.eval_frequency == 0:
                eval_r = self.evaluate()
                print('epoch', epoch, 'evaluate reward', eval_r)
                self.writer.add_scalar('reward', eval_r, self.total_step)
                # if eval_r > best_eval:
                #     best_eval = eval_r
                #     save_model(self.critic, 'model/{}_model/best_critic'.format(self.env_name))
                #     save_model(self.actor, 'model/{}_model/best_actor'.format(self.env_name))
                #     ZFilter.save(self.state_normalize, 'model/{}_model/best_rs'.format(self.env_name))

    def learn(self):
        state, action, reward, next_state, done, value = self.trace.get()
        advantage, total_reward = self.trace.cal_advantage(
            self.gamma, self.lam)
        action = torch.tensor(action, dtype=torch.float).to(self.device)
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        advantage = torch.tensor(advantage, dtype=torch.float).to(self.device)
        total_reward = torch.tensor(total_reward,
                                    dtype=torch.float).to(self.device)

        # update critic
        for _ in range(self.train_iters):
            value = self.critic(state).squeeze(1)
            critic_loss = self.loss_fn(value, total_reward)
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

        # update policy
        # both log-probs are evaluated with the same (current) policy:
        # detaching log_prob_old makes ratio_old equal to 1 in the forward pass
        # while keeping the gradient of the surrogate w.r.t. the current policy
        log_prob_old = self.new_policy.get_log_prob(state, action)
        log_prob_new = self.new_policy.get_log_prob(state, action)
        ratio_old = torch.exp(log_prob_new - log_prob_old.detach())
        policy_loss_old = (ratio_old * advantage).mean()

        gradient = torch.autograd.grad(policy_loss_old,
                                       self.new_policy.parameters())
        gradient = TRPO.flatten_tuple(gradient)

        x = self.cg(state, gradient)
        gHg = (self.get_hessian_dot_vec(state, x) * x).sum(0)
        step_size = torch.sqrt(2 * self.delta / (gHg + 1e-8))
        old_params = self.flatten_tuple(self.new_policy.parameters())
        update_model(self.old_policy, self.new_policy)

        # backtracking line search
        expected_improve = (gradient * step_size * x).sum()
        print(expected_improve)
        tmp_backtrack_coeff = self.backtrack_coeff
        for _ in range(self.train_iters):
            new_params = old_params + tmp_backtrack_coeff * step_size * x
            idx = 0
            for param in self.new_policy.parameters():
                param_len = len(param.view(-1))
                new_param = new_params[idx:idx + param_len]
                new_param = new_param.view(param.size())
                param.data.copy_(new_param)
                idx += param_len

            log_prob = self.new_policy.get_log_prob(state, action)
            ratio = torch.exp(log_prob - log_prob_old)
            policy_loss = (ratio * advantage).mean()
            loss_improve = policy_loss - policy_loss_old
            expected_improve *= tmp_backtrack_coeff
            improve_condition = (loss_improve /
                                 (expected_improve + 1e-8)).item()

            kl = (self.kl_divergence(self.old_policy, self.new_policy,
                                     state)).mean()

            if kl < self.delta and improve_condition > self.backtrack_alpha:
                break

            tmp_backtrack_coeff *= self.backtrack_damp_coeff

    def cg(self, state, g, n_iters=20):
        # conjugate gradient algorithm to solve the linear system Hx = g, where H is symmetric and positive-definite
        # repeat:
        #   alpha_k = (r_k.T * r_k) / (p_k.T * H * p_k)
        #   x_k+1 = x_k + alpha_k * p_k
        #   r_k+1 = r_k - alpha_k * H * p_k
        #   beta_k+1 = (r_k+1.T * r_k+1) / (r_k.T * r_k)
        #   p_k+1 = r_k+1 + beta_k+1 * p_k
        #   k = k + 1
        # end repeat
        x = torch.zeros_like(g).to(self.device)
        r = g.clone().to(self.device)
        p = g.clone().to(self.device)
        rdotr = torch.dot(r, r).to(self.device)

        for _ in range(n_iters):
            Hp = self.get_hessian_dot_vec(state, p)
            alpha = rdotr / (torch.dot(p, Hp) + 1e-8)
            x += alpha * p
            r -= alpha * Hp
            new_rdotr = torch.dot(r, r)
            beta = new_rdotr / (rdotr + 1e-8)
            p = r + beta * p
            rdotr = new_rdotr
        return x

    def kl_divergence(self, old_policy, new_policy, state):
        # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
        # KL(p, q) = log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2*sigma_2^2) - 0.5
        state = torch.as_tensor(state, dtype=torch.float).to(self.device)

        mu_1, log_sigma_1 = old_policy(state)
        mu_1 = mu_1.detach()
        sigma_1 = log_sigma_1.exp().detach()

        mu_2, log_sigma_2 = new_policy(state)
        sigma_2 = log_sigma_2.exp()
        kl = torch.log(sigma_2 / sigma_1) + (
            sigma_1.pow(2) +
            (mu_1 - mu_2).pow(2)) / (2 * sigma_2.pow(2) + 1e-8) - 0.5

        return kl

    def get_hessian_dot_vec(self, state, vec, damping_coeff=0.01):
        kl = self.kl_divergence(self.old_policy, self.new_policy, state)
        kl_mean = kl.mean()
        gradient = torch.autograd.grad(kl_mean,
                                       self.new_policy.parameters(),
                                       create_graph=True)
        gradient = self.flatten_tuple(gradient)

        kl_grad_p = (gradient * vec).sum()
        kl_hessian = torch.autograd.grad(kl_grad_p,
                                         self.new_policy.parameters())
        kl_hessian = self.flatten_tuple(kl_hessian)

        return kl_hessian + damping_coeff * vec

    @staticmethod
    def flatten_tuple(t):
        flatten_t = torch.cat([data.view(-1) for data in t])
        return flatten_t

    def evaluate(self, epochs=3, is_render=False):
        eval_r = 0
        for _ in range(epochs):
            s = self.env.reset()
            # s = self.state_normalize(s, update=False)
            while True:
                if is_render:
                    self.env.render()
                a = self.select_action(s, is_test=True)
                s_, r, done, _ = self.env.step(a)
                # s_ = self.state_normalize(s_, update=False)
                s = s_
                eval_r += r
                if done:
                    break
        return eval_r / epochs
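The conjugate-gradient routine in the TRPO class only ever sees Hessian-vector products. As a sanity check, the same loop can be run against an explicit symmetric positive-definite matrix and compared with a direct solve (illustration only, not part of the original code):

import torch

def cg_dense(H, g, n_iters=20):
    # identical update rule to TRPO.cg, but with an explicit matrix H
    x = torch.zeros_like(g)
    r = g.clone()
    p = g.clone()
    rdotr = torch.dot(r, r)
    for _ in range(n_iters):
        Hp = H @ p
        alpha = rdotr / (torch.dot(p, Hp) + 1e-8)
        x = x + alpha * p
        r = r - alpha * Hp
        new_rdotr = torch.dot(r, r)
        beta = new_rdotr / (rdotr + 1e-8)
        p = r + beta * p
        rdotr = new_rdotr
    return x

A = torch.randn(5, 5)
H = A @ A.t() + 5 * torch.eye(5)      # symmetric positive-definite
g = torch.randn(5)
x = cg_dense(H, g)
print(torch.allclose(H @ x, g, atol=1e-4))   # expected: True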
Example #7
class A2C:
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-3,
                 gamma=0.99,
                 is_continue_action_space=False,
                 reward_shaping_func=lambda x: x[1],
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=10):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.gamma = gamma
        self.reward_shaping_func = reward_shaping_func
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs_step_update/A2C_{}'.format(
                self.env_name))
        self.loss_fn = F.mse_loss

        if is_continue_action_space:
            self.actor = GaussianActor(
                env.observation_space.shape[0],
                env.action_space.shape[0],
                action_scale=int(env.action_space.high[0])).to(self.device)
        else:
            self.actor = Actor(env.observation_space.shape[0],
                               env.action_space.n).to(self.device)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic = Critic(env.observation_space.shape[0]).to(self.device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        print(self.actor)
        print(self.critic)

    def select_action(self, state, is_test=False):
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        a, log_prob = self.actor.sample(state, is_test)
        return a.cpu().detach().numpy(), log_prob

    def train(self, epochs):
        best_eval = -1e6
        for epoch in range(epochs):
            s = self.env.reset()
            while True:
                self.env.render()
                a, log_prob = self.select_action(s)
                s_, r, done, _ = self.env.step(a)
                r = self.reward_shaping_func((s_, r, done, _))
                policy_loss, critic_loss = self.learn(s, a, log_prob, r, s_,
                                                      done)
                s = s_
                if done:
                    break

            self.writer.add_scalar('loss/actor_loss', policy_loss, epoch)
            self.writer.add_scalar('loss/critic_loss', critic_loss, epoch)

            if (epoch + 1) % self.save_model_frequency == 0:
                save_model(
                    self.critic, 'model_step_update/{}_model/critic_{}'.format(
                        self.env_name, epoch))
                save_model(
                    self.actor, 'model_step_update/{}_model/actor_{}'.format(
                        self.env_name, epoch))

            if (epoch + 1) % self.eval_frequency == 0:
                eval_r = self.evaluate()
                print('epoch', epoch, 'evaluate reward', eval_r)
                self.writer.add_scalar('reward', eval_r, epoch)
                if eval_r > best_eval:
                    best_eval = eval_r
                    save_model(
                        self.critic,
                        'model_step_update/{}_model/best_critic'.format(
                            self.env_name))
                    save_model(
                        self.actor,
                        'model_step_update/{}_model/best_actor'.format(
                            self.env_name))

    def learn(self, s, a, log_prob, r, s_, done):
        mask = not done
        next_q = self.critic(
            torch.tensor(s_, dtype=torch.float).to(self.device))
        target_q = r + mask * self.gamma * next_q
        pred_v = self.critic(
            torch.tensor(s, dtype=torch.float).to(self.device))
        critic_loss = self.loss_fn(pred_v, target_q.detach())

        self.critic_opt.zero_grad()
        critic_loss.backward()
        # for p in filter(lambda p: p.grad is not None, self.critic.parameters()):
        #     p.grad.data.clamp_(min=-1, max=1)
        self.critic_opt.step()

        advantage = (target_q - pred_v).detach()
        policy_loss = -advantage * log_prob + 0.01 * log_prob.exp() * log_prob

        self.actor_opt.zero_grad()
        policy_loss.backward()
        # for p in filter(lambda p: p.grad is not None, self.actor.parameters()):
        #     p.grad.data.clamp_(min=-1, max=1)
        self.actor_opt.step()
        # print(policy_loss, critic_loss, policy_loss.item(), critic_loss.item())
        return policy_loss.item(), critic_loss.item()

    def evaluate(self, epochs=3, is_render=False):
        eval_r = 0
        for _ in range(epochs):
            s = self.env.reset()
            while True:
                if is_render:
                    self.env.render()
                with torch.no_grad():
                    a, _ = self.select_action(s, is_test=True)
                s_, r, done, _ = self.env.step(a)
                s = s_
                eval_r += r
                if done:
                    break
        return eval_r / epochs
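Examples #4 and #7 fall back to a discrete Actor when the action space is not continuous; that class is not included in the listing. A minimal sketch, assuming sample() returns an action and its log-probability, with is_test switching to the greedy action:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class Actor(nn.Module):
    def __init__(self, n_state, n_action, n_hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_state, n_hidden), nn.ReLU(),
                                 nn.Linear(n_hidden, n_action))

    def forward(self, state):
        return F.softmax(self.net(state), dim=-1)

    def sample(self, state, is_test=False):
        probs = self.forward(state)
        dist = Categorical(probs)
        a = probs.argmax(dim=-1) if is_test else dist.sample()
        return a, dist.log_prob(a)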
Example #8
class DDPG:

    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-3,
                 gamma=0.99,
                 batch_size=32,
                 replay_memory_size=1e6,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=10):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.gamma = gamma
        self.batch_size = batch_size
        self.replay_memory_size = replay_memory_size
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs/DDPG_{}'.format(self.env_name))
        self.loss_fn = F.mse_loss
        self.memory = Memory(int(replay_memory_size), batch_size)

        n_state, n_action = env.observation_space.shape[0], env.action_space.shape[0]
        self.noise = OUNoise(n_action)

        self.actor = DeterministicActor(n_state, n_action,
                                        action_scale=int(env.action_space.high[0])).to(self.device)
        self.target_actor = DeterministicActor(n_state, n_action,
                                               action_scale=int(env.action_space.high[0])).to(self.device)
        update_model(self.target_actor, self.actor)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)

        self.critic = Critic(n_state + n_action).to(self.device)
        self.target_critic = Critic(n_state + n_action).to(self.device)
        update_model(self.target_critic, self.critic)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        print(self.actor)
        print(self.critic)

    def select_action(self, state, is_test=False):
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        if is_test:
            a = self.actor(state)
        else:
            a = self.actor(state) + torch.tensor(self.noise(), dtype=torch.float).to(self.device)
            a = a.clip(-self.actor.action_scale, self.actor.action_scale)
        return a.cpu().detach().numpy()

    def train(self, epochs):
        best_eval = -1e6
        for epoch in range(epochs):
            s = self.env.reset()
            policy_loss, critic_loss = 0, 0
            while True:
                self.env.render()
                a = self.select_action(s)
                s_, r, done, _ = self.env.step(a)
                self.memory.push(s, a, r, s_, done)
                if len(self.memory) > self.batch_size:
                    policy_loss, critic_loss = self.learn()
                s = s_
                if done:
                    break

            self.writer.add_scalar('loss/actor_loss', policy_loss, epoch)
            self.writer.add_scalar('loss/critic_loss', critic_loss, epoch)

            if (epoch + 1) % self.save_model_frequency == 0:
                save_model(self.critic, 'model/{}_model/critic_{}'.format(self.env_name, epoch))
                save_model(self.actor, 'model/{}_model/actor_{}'.format(self.env_name, epoch))

            if (epoch + 1) % self.eval_frequency == 0:
                eval_r = self.evaluate()
                print('epoch', epoch, 'evaluate reward', eval_r)
                self.writer.add_scalar('reward', eval_r, epoch)
                if eval_r > best_eval:
                    best_eval = eval_r
                    save_model(self.critic, 'model/{}_model/best_critic'.format(self.env_name))
                    save_model(self.actor, 'model/{}_model/best_actor'.format(self.env_name))

    def learn(self):
        batch = self.memory.sample()
        batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
            batch.state, batch.action, batch.reward, batch.next_state, batch.done
        batch_state = torch.tensor(batch_state, dtype=torch.float).to(self.device)
        batch_action = torch.tensor(batch_action, dtype=torch.float).reshape(self.batch_size, -1).to(self.device)
        batch_reward = torch.tensor(batch_reward, dtype=torch.float).reshape(self.batch_size, -1).to(self.device)
        batch_next_state = torch.tensor(batch_next_state, dtype=torch.float).to(self.device)
        batch_mask = torch.tensor([not i for i in batch_done], dtype=torch.bool).reshape(self.batch_size, -1).to(self.device)

        # update critic
        pred_q = self.critic(torch.cat((batch_state, batch_action), dim=-1))
        next_action = self.target_actor(batch_next_state)
        next_q = self.target_critic(torch.cat((batch_next_state, next_action), dim=-1))
        pred_target_q = batch_reward + batch_mask * self.gamma * next_q
        critic_loss = self.loss_fn(pred_q, pred_target_q.detach())

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # update actor
        policy_loss = - self.critic(torch.cat((batch_state, self.actor(batch_state)), dim=-1)).mean()
        self.actor_opt.zero_grad()
        policy_loss.backward()
        self.actor_opt.step()

        # update target
        update_model(self.target_critic, self.critic, 0.05)
        update_model(self.target_actor, self.actor, 0.05)

        return policy_loss.item(), critic_loss.item()

    def evaluate(self, epochs=3, is_render=False):
        eval_r = 0
        for _ in range(epochs):
            s = self.env.reset()
            while True:
                if is_render:
                    self.env.render()
                with torch.no_grad():
                    a = self.select_action(s, is_test=True)
                s_, r, done, _ = self.env.step(a)
                s = s_
                eval_r += r
                if done:
                    break
        return eval_r / epochs
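The DDPG class stores transitions in a Memory object whose definition is not shown. Judging from memory.push(s, a, r, s_, done), len(self.memory) and the batch.state / batch.action field access in learn(), it is assumed to be a namedtuple-based replay buffer roughly like this:

import random
from collections import deque, namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

class Memory:
    def __init__(self, capacity, batch_size):
        self.buffer = deque(maxlen=capacity)
        self.batch_size = batch_size

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self):
        transitions = random.sample(self.buffer, self.batch_size)
        # zip(*) turns a list of transitions into one Transition of field tuples
        return Transition(*zip(*transitions))

    def __len__(self):
        return len(self.buffer)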
Example #9
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 sample_size=2048,
                 batch_size=64,
                 sample_reuse=1,
                 train_iters=5,
                 clip=0.2,
                 gamma=0.99,
                 lam=0.95,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=5,
                 save_log_frequency=1):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.sample_size = sample_size
        self.batch_size = batch_size
        self.sample_reuse = sample_reuse
        self.train_iters = train_iters
        self.clip = clip
        self.gamma = gamma
        self.lam = lam
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency
        self.save_log_frequency = save_log_frequency

        self.total_step = 0
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs/PPO_{}'.format(self.env_name))
        self.loss_fn = F.mse_loss

        n_state, n_action = env.observation_space.shape[
            0], env.action_space.shape[0]
        self.state_normalize = ZFilter(n_state)
        self.actor = GaussianActor(n_state,
                                   n_action,
                                   128,
                                   action_scale=int(env.action_space.high[0]),
                                   weights_init_=orthogonal_weights_init_).to(
                                       self.device)
        self.critic = Critic(n_state, 128,
                             orthogonal_weights_init_).to(self.device)

        # self.optimizer = optim.Adam([
        #     {'params': self.critic.parameters(), 'lr': self.critic_lr},
        #     {'params': self.actor.parameters(), 'lr': self.actor_lr}
        # ])
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.trace = Trace()

        print(self.actor)
        print(self.critic)
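The PPO snippets pass an orthogonal_weights_init_ callable into the actor and critic; its body is not part of these examples. A plausible sketch, assuming it is applied to linear layers via Module.apply (the gain value is a guess):

import torch.nn as nn

def orthogonal_weights_init_(m):
    # intended to be used as model.apply(orthogonal_weights_init_)
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=2 ** 0.5)
        nn.init.constant_(m.bias, 0.0)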
Example #10
class PPO:
    def __init__(self,
                 env_name,
                 env,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 sample_size=2048,
                 batch_size=64,
                 sample_reuse=1,
                 train_iters=5,
                 clip=0.2,
                 gamma=0.99,
                 lam=0.95,
                 is_test=False,
                 save_model_frequency=200,
                 eval_frequency=5,
                 save_log_frequency=1):
        self.env_name = env_name
        self.env = env
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.sample_size = sample_size
        self.batch_size = batch_size
        self.sample_reuse = sample_reuse
        self.train_iters = train_iters
        self.clip = clip
        self.gamma = gamma
        self.lam = lam
        self.save_model_frequency = save_model_frequency
        self.eval_frequency = eval_frequency
        self.save_log_frequency = save_log_frequency

        self.total_step = 0
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        print('Train on device:', self.device)
        if not is_test:
            self.writer = SummaryWriter('./logs/PPO_{}'.format(self.env_name))
        self.loss_fn = F.mse_loss

        n_state, n_action = env.observation_space.shape[
            0], env.action_space.shape[0]
        self.state_normalize = ZFilter(n_state)
        self.actor = GaussianActor(n_state,
                                   n_action,
                                   128,
                                   action_scale=int(env.action_space.high[0]),
                                   weights_init_=orthogonal_weights_init_).to(
                                       self.device)
        self.critic = Critic(n_state, 128,
                             orthogonal_weights_init_).to(self.device)

        # self.optimizer = optim.Adam([
        #     {'params': self.critic.parameters(), 'lr': self.critic_lr},
        #     {'params': self.actor.parameters(), 'lr': self.actor_lr}
        # ])
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.trace = Trace()

        print(self.actor)
        print(self.critic)

    def select_action(self, state, is_test=False):
        state = torch.tensor(state,
                             dtype=torch.float).unsqueeze(0).to(self.device)
        a, log_prob = self.actor.sample(state, is_test)
        return a.cpu().detach().numpy()[0], log_prob

    def train(self, epochs):
        best_eval = -1e6
        for epoch in range(epochs):
            num_sample = 0
            self.trace.clear()
            s = self.env.reset()
            s = self.state_normalize(s)
            while True:
                # self.env.render()
                a, log_prob = self.select_action(s)
                log_prob = torch.sum(log_prob, dim=1, keepdim=True)
                v = self.critic(
                    torch.tensor(s, dtype=torch.float).unsqueeze(0).to(
                        self.device))
                s_, r, done, _ = self.env.step(a)
                s_ = self.state_normalize(s_)
                self.trace.push(s, a,
                                log_prob.cpu().detach().numpy()[0], r, s_,
                                not done, v)
                num_sample += 1
                self.total_step += 1
                s = s_
                if done and num_sample >= self.sample_size:
                    break
                if done:
                    s = self.env.reset()
                    s = self.state_normalize(s)

            policy_loss, critic_loss = self.learn()

            if (epoch + 1) % self.save_log_frequency == 0:
                self.writer.add_scalar('loss/critic_loss', critic_loss,
                                       self.total_step)
                self.writer.add_scalar('loss/policy_loss', policy_loss,
                                       self.total_step)

            if (epoch + 1) % self.save_model_frequency == 0:
                save_model(
                    self.critic,
                    'model/{}_model/critic_{}'.format(self.env_name, epoch))
                save_model(
                    self.actor,
                    'model/{}_model/actor_{}'.format(self.env_name, epoch))
                ZFilter.save(
                    self.state_normalize,
                    'model/{}_model/rs_{}'.format(self.env_name, epoch))

            if (epoch + 1) % self.eval_frequency == 0:
                eval_r = self.evaluate()
                print('epoch', epoch, 'evaluate reward', eval_r)
                self.writer.add_scalar('reward', eval_r, self.total_step)
                if eval_r > best_eval:
                    best_eval = eval_r
                    save_model(
                        self.critic,
                        'model/{}_model/best_critic'.format(self.env_name))
                    save_model(
                        self.actor,
                        'model/{}_model/best_actor'.format(self.env_name))
                    ZFilter.save(
                        self.state_normalize,
                        'model/{}_model/best_rs'.format(self.env_name))

    def learn(self):
        all_data = self.trace.get()
        data_idx_range = np.arange(len(self.trace))
        adv, total_reward = self.trace.cal_advantage(self.gamma, self.lam)
        adv = adv.reshape(len(self.trace), -1).to(self.device)
        total_reward = total_reward.reshape(len(self.trace),
                                            -1).to(self.device)

        all_state = torch.tensor(all_data.state,
                                 dtype=torch.float).to(self.device)
        all_action = torch.tensor(all_data.action, dtype=torch.float).reshape(
            len(self.trace), -1).to(self.device)
        all_log_prob = torch.tensor(all_data.log_prob,
                                    dtype=torch.float).reshape(
                                        len(self.trace), -1).to(self.device)
        all_value = torch.tensor(all_data.value, dtype=torch.float).reshape(
            len(self.trace), -1).to(self.device)

        policy_loss, critic_loss = 0, 0
        # train_iters = max(int(self.sample_size * self.sample_reuse / self.batch_size), 1)

        for _ in range(self.train_iters):
            batch_idx = np.random.choice(data_idx_range,
                                         self.batch_size,
                                         replace=False)
            batch_state = all_state[batch_idx]
            batch_action = all_action[batch_idx]
            batch_log_prob = all_log_prob[batch_idx]
            batch_new_log_prob = self.actor.get_log_prob(
                batch_state, batch_action)
            batch_new_log_prob = batch_new_log_prob.sum(dim=1, keepdim=True)
            batch_old_value = all_value[batch_idx]
            batch_new_value = self.critic(batch_state)
            batch_adv = adv[batch_idx]
            batch_total_reward = total_reward[batch_idx]

            ratio = torch.exp(batch_new_log_prob - batch_log_prob)
            surr1 = ratio * batch_adv.detach()
            surr2 = ratio.clamp(1 - self.clip,
                                1 + self.clip) * batch_adv.detach()
            entropy_loss = (torch.exp(batch_new_log_prob) *
                            batch_new_log_prob).mean()
            policy_loss = -torch.min(surr1, surr2).mean() + 0.01 * entropy_loss

            self.actor_opt.zero_grad()
            policy_loss.backward()
            self.actor_opt.step()

            # print(batch_new_value, batch_total_reward)
            clip_v = batch_old_value + torch.clamp(
                batch_new_value - batch_old_value, -self.clip, self.clip)
            critic_loss = torch.max(
                self.loss_fn(batch_new_value, batch_total_reward),
                self.loss_fn(clip_v, batch_total_reward),
            )
            critic_loss = critic_loss.mean() / (6 * batch_total_reward.std())
            # critic_loss = self.loss_fn(batch_new_value, batch_total_reward)

            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # isn't the critic loss way too large here? can this even be optimized?
            # loss = policy_loss + 0.5 * critic_loss + 0.01 * entropy_loss
            #
            # self.optimizer.zero_grad()
            # loss.backward()
            # # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)  # self.max_grad_norm = 0.5
            # # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
            # self.optimizer.step()

        return policy_loss.item(), critic_loss.item()

    def evaluate(self, epochs=3, is_render=False):
        eval_r = 0
        for _ in range(epochs):
            s = self.env.reset()
            s = self.state_normalize(s, update=False)
            while True:
                if is_render:
                    self.env.render()
                a, _ = self.select_action(s, is_test=True)
                s_, r, done, _ = self.env.step(a)
                s_ = self.state_normalize(s_, update=False)
                s = s_
                eval_r += r
                if done:
                    break
        return eval_r / epochs
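None of the snippets show how the agents are driven. A hypothetical entry point, assuming the old gym API that the code above already relies on (reset() returning only the observation, step() returning a 4-tuple); the environment id is purely illustrative:

import gym

if __name__ == '__main__':
    env = gym.make('Pendulum-v0')
    agent = PPO('Pendulum-v0', env)
    agent.train(epochs=1000)          # writes TensorBoard logs to ./logs/PPO_Pendulum-v0
    print('mean eval reward:', agent.evaluate(epochs=5))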