class TRPOAgent:
    """Trust Region Policy Optimization agent for the gym env ``ENV_ID``.

    Each iteration of :meth:`play`:

    1. collects ``TRAJECTORY_SIZE`` transitions with the current (Gaussian)
       policy;
    2. estimates advantages with GAE;
    3. takes one KL-constrained step on the policy;
    4. fits the value network to the GAE returns.
    """

    TRAJECTORY_SIZE = 1024

    VF_BATCHSIZE = 64

    MAX_KL = 0.01

    GAMMA = 0.99

    GAE_LAMBDA = 0.98

    ENV_ID = "Pendulum-v0"

    OBS_SPACE = 3

    ACTION_SPACE = 1

    def __init__(self):
        self.policy = PolicyNetwork(action_space=self.ACTION_SPACE)
        self.value_network = ValueNetwork()
        self.env = gym.make(self.ENV_ID)
        self.global_steps = 0
        # Undiscounted return of every finished training episode, in order.
        self.history = []
        # Best 10-episode moving average seen so far (None until first save).
        self.hiscore = None

    def play(self, n_iters):
        """Run ``n_iters`` collect-and-update iterations.

        Resets the environment and the running episode counters first.

        Args:
            n_iters (int): number of trajectory / update iterations.

        Returns:
            list: ``self.history``, the returns of all finished episodes.
        """
        self.epi_reward = 0
        self.epi_steps = 0
        self.state = self.env.reset()

        for _ in range(n_iters):
            trajectory = self.generate_trajectory()
            trajectory = self.compute_advantage(trajectory)
            self.update_policy(trajectory)
            self.update_vf(trajectory)

        return self.history

    def generate_trajectory(self):
        """Collect ``TRAJECTORY_SIZE`` transitions with the current policy.

        Rollouts continue from ``self.state``, which is set by :meth:`play`.
        The env is reset whenever an episode ends, and a summary is printed
        for each finished episode. Once more than 100 episodes have finished,
        the model is saved whenever the mean of the last 10 episode returns
        reaches a new high.

        Returns:
            dict: float32 arrays keyed ``"s"``, ``"a"``, ``"r"``, ``"s2"``
            and ``"done"``, each with ``TRAJECTORY_SIZE`` rows.
        """
        trajectory = {
            "s": np.zeros((self.TRAJECTORY_SIZE, self.OBS_SPACE),
                          dtype=np.float32),
            "a": np.zeros((self.TRAJECTORY_SIZE, self.ACTION_SPACE),
                          dtype=np.float32),
            "r": np.zeros((self.TRAJECTORY_SIZE, 1), dtype=np.float32),
            "s2": np.zeros((self.TRAJECTORY_SIZE, self.OBS_SPACE),
                           dtype=np.float32),
            "done": np.zeros((self.TRAJECTORY_SIZE, 1), dtype=np.float32),
        }

        state = self.state

        for i in range(self.TRAJECTORY_SIZE):
            action = self.policy.sample_action(state)
            next_state, reward, done, _ = self.env.step(action)

            trajectory["s"][i] = state
            trajectory["a"][i] = action
            trajectory["r"][i] = reward
            trajectory["s2"][i] = next_state
            trajectory["done"][i] = done

            self.epi_reward += reward
            self.epi_steps += 1
            self.global_steps += 1

            if done:
                state = self.env.reset()

                self.history.append(self.epi_reward)
                recent_score = sum(self.history[-10:]) / 10

                print("====" * 5)
                print("Episode:", len(self.history))
                print("Episode reward:", self.epi_reward)
                print("Global steps:", self.global_steps)

                if len(self.history) > 100 and (self.hiscore is None or
                                                recent_score > self.hiscore):
                    print("*HISCORE UPDATED:", recent_score)
                    self.save_model()
                    self.hiscore = recent_score

                self.epi_reward = 0
                self.epi_steps = 0
            else:
                state = next_state

        # Remember where the rollout stopped so the next call continues it.
        self.state = state

        return trajectory

    def compute_advantage(self, trajectory):
        """Add GAE advantages and value-function targets to ``trajectory``.

        Adds these keys:

        * ``"vpred"`` and ``"vpred_next"``: value predictions for ``s``
          and ``s2``.
        * ``"adv"``: advantages standardized to zero mean and unit std.
          These are used by the policy update.
        * ``"vftarget"``: raw GAE advantage plus V(s), i.e. the GAE(lambda)
          return. This is the regression target for the value network.

        Args:
            trajectory (dict): buffers produced by :meth:`generate_trajectory`.

        Returns:
            dict: the same ``trajectory`` object, updated in place.
        """
        trajectory["vpred"] = self.value_network(trajectory["s"]).numpy()
        trajectory["vpred_next"] = self.value_network(trajectory["s2"]).numpy()

        # Do not bootstrap (or propagate GAE) across episode boundaries.
        is_nonterminals = 1 - trajectory["done"]

        deltas = (trajectory["r"]
                  + self.GAMMA * is_nonterminals * trajectory["vpred_next"]
                  - trajectory["vpred"])

        advantages = np.zeros_like(deltas, dtype=np.float32)
        lastgae = 0
        for i in reversed(range(len(deltas))):
            lastgae = (deltas[i] + self.GAMMA * self.GAE_LAMBDA
                       * is_nonterminals[i] * lastgae)
            advantages[i] = lastgae

        trajectory["adv"] = ((advantages - advantages.mean())
                             / (advantages.std() + 1e-8))

        # The value target must use the *raw* advantages: A + V(s) is the
        # GAE(lambda) return, whereas standardized A + V(s) is a rescaled,
        # shifted quantity the critic should not regress onto.
        trajectory["vftarget"] = advantages + trajectory["vpred"]

        return trajectory

    def _set_policy_params(self, params):
        """Write ``params`` (one array per trainable variable) into the policy."""
        # Assign variable-by-variable instead of set_weights(): params only
        # covers trainable_variables, while set_weights expects every weight.
        for var, value in zip(self.policy.trainable_variables, params):
            var.assign(value)

    def update_policy(self, trajectory):
        """Take one TRPO step on the policy.

        Steps:

        1. Solve ``H x = g`` with conjugate gradient (``cg``). Here H is the
           damped Hessian of the mean KL between the pre-update and current
           policy, and g is the gradient of the surrogate objective.
        2. Scale the step so that ``0.5 * x^T H x == MAX_KL``.
        3. Backtrack (step sizes 1, 1/2, ..., 1/512) until the mean KL is at
           most ``1.5 * MAX_KL`` and the surrogate has not decreased.

        If no step size passes, the original parameters are restored.

        Args:
            trajectory (dict): output of :meth:`compute_advantage`.
        """
        def flattengrads(grads):
            # Concatenate per-variable gradients into one (1, n_params) row.
            return tf.concat([tf.reshape(grad, shape=[1, -1]) for grad in grads],
                             axis=1)

        actions = tf.convert_to_tensor(trajectory["a"], dtype=tf.float32)
        states = tf.convert_to_tensor(trajectory["s"], dtype=tf.float32)
        advantages = tf.convert_to_tensor(trajectory["adv"], dtype=tf.float32)

        # Action distribution of the policy before this update (held fixed).
        old_means, old_stdevs = self.policy(states)
        old_logp = compute_logprob(old_means, old_stdevs, actions)

        with tf.GradientTape() as tape:
            new_means, new_stdevs = self.policy(states)
            new_logp = compute_logprob(new_means, new_stdevs, actions)
            loss = tf.reduce_mean(tf.exp(new_logp - old_logp) * advantages)

        g = tape.gradient(loss, self.policy.trainable_variables)
        g = tf.transpose(flattengrads(g))  # column vector (n_params, 1)

        # NOTE: defined per call so it can close over this batch's tensors,
        # which means the tf.function is traced again on every update.
        @tf.function
        def hvp_func(vector):
            """Compute the damped KL Hessian-vector product ``H @ vector``."""
            with tf.GradientTape() as t2:
                with tf.GradientTape() as t1:
                    new_means, new_stdevs = self.policy(states)
                    kl = compute_kl(old_means, old_stdevs, new_means,
                                    new_stdevs)
                    meankl = tf.reduce_mean(kl)

                kl_grads = t1.gradient(meankl, self.policy.trainable_variables)
                kl_grads = flattengrads(kl_grads)
                grads_vector_product = tf.matmul(kl_grads, vector)

            hvp = t2.gradient(grads_vector_product,
                              self.policy.trainable_variables)
            hvp = tf.transpose(flattengrads(hvp))

            # Add a small damping term to stabilize conjugate gradient.
            return hvp + vector * 1e-2

        step_direction = cg(hvp_func, g)

        shs = tf.matmul(tf.transpose(step_direction), hvp_func(step_direction))
        lm = tf.sqrt(2 * self.MAX_KL / shs)
        fullstep = lm * step_direction

        expected_improve = tf.matmul(tf.transpose(g), fullstep)
        fullstep = restore_shape(fullstep, self.policy.trainable_variables)

        params_old = [var.numpy() for var in self.policy.trainable_variables]
        old_loss = loss

        for stepsize in [0.5**i for i in range(10)]:
            params_new = [
                p + step * stepsize for p, step in zip(params_old, fullstep)
            ]
            self._set_policy_params(params_new)

            new_means, new_stdevs = self.policy(states)
            new_logp = compute_logprob(new_means, new_stdevs, actions)

            new_loss = tf.reduce_mean(tf.exp(new_logp - old_logp) * advantages)
            improve = new_loss - old_loss

            kl = compute_kl(old_means, old_stdevs, new_means, new_stdevs)
            mean_kl = tf.reduce_mean(kl)

            print(f"Expected: {expected_improve} Actual: {improve}")
            print(f"KL {mean_kl}")

            if mean_kl > self.MAX_KL * 1.5:
                print("violated KL constraint. shrinking step.")
            elif improve < 0:
                print("surrogate didn't improve. shrinking step.")
            else:
                print("Stepsize OK!")
                break
        else:
            # No acceptable step size ("update failed"): roll back.
            print("更新に失敗")
            self._set_policy_params(params_old)

    def update_vf(self, trajectory):
        """Fit the value network to ``trajectory["vftarget"]``.

        Runs ``TRAJECTORY_SIZE // VF_BATCHSIZE`` optimizer steps. Each step
        uses a minibatch sampled with replacement and minimizes the mean
        squared error.
        """
        for _ in range(self.TRAJECTORY_SIZE // self.VF_BATCHSIZE):
            indx = np.random.choice(self.TRAJECTORY_SIZE,
                                    self.VF_BATCHSIZE,
                                    replace=True)

            with tf.GradientTape() as tape:
                vpred = self.value_network(trajectory["s"][indx])
                vtarget = trajectory["vftarget"][indx]
                loss = tf.reduce_mean(tf.square(vtarget - vpred))

            variables = self.value_network.trainable_variables
            grads = tape.gradient(loss, variables)
            self.value_network.optimizer.apply_gradients(zip(grads, variables))

    def save_model(self):
        """Save policy and value-network weights under ``checkpoints/``."""
        self.policy.save_weights("checkpoints/actor")
        self.value_network.save_weights("checkpoints/critic")

        print()
        print("Model Saved")
        print()

    def load_model(self):
        """Load policy and value-network weights from ``checkpoints/``."""
        self.policy.load_weights("checkpoints/actor")
        self.value_network.load_weights("checkpoints/critic")

    def test_play(self, n, monitordir, load_model=False):
        """Run ``n`` evaluation episodes and print each episode's return.

        Args:
            n (int): number of episodes.
            monitordir: if truthy, wrap the env in ``wrappers.Monitor``
                writing to this directory and record every episode.
            load_model (bool): load saved weights before playing.
        """
        if load_model:
            self.load_model()

        if monitordir:
            env = wrappers.Monitor(gym.make(self.ENV_ID),
                                   monitordir,
                                   force=True,
                                   video_callable=(lambda ep: True))
        else:
            env = gym.make(self.ENV_ID)

        try:
            for i in range(n):
                total_reward = 0
                steps = 0
                done = False
                state = env.reset()

                while not done:
                    action = self.policy.sample_action(state)
                    state, reward, done, _ = env.step(action)
                    total_reward += reward
                    steps += 1

                print()
                print(f"Test Play {i}: {total_reward}")
                print("Steps:", steps)
                print()
        finally:
            # Release the env. For a Monitor wrapper, this also finalizes
            # the recording output.
            env.close()
class PPOAgent:
    """Proximal Policy Optimization agent with a clipped surrogate objective.

    Each epoch runs these steps:

    1. Collect ``trajectory_size`` steps from each of ``n_envs``
       environments through ``VecEnv``.
    2. Compute GAE advantages on rewards scaled by their running standard
       deviation.
    3. Run ``OPT_ITER`` minibatch updates on the critic and the policy.
    """

    GAMMA = 0.99

    GAE_LAMBDA = 0.95

    CLIPRANGE = 0.2

    OPT_ITER = 20

    BATCH_SIZE = 2048

    def __init__(self,
                 env_id,
                 action_space,
                 trajectory_size=256,
                 n_envs=1,
                 max_timesteps=1500):
        self.env_id = env_id
        self.n_envs = n_envs
        self.trajectory_size = trajectory_size

        self.vecenv = VecEnv(env_id=self.env_id,
                             n_envs=self.n_envs,
                             max_timesteps=max_timesteps)

        self.policy = PolicyNetwork(action_space=action_space)
        self.old_policy = PolicyNetwork(action_space=action_space)
        self.critic = CriticNetwork()

        # Statistics of the per-step reward, which is a scalar regardless of
        # the action dimensionality. This assumes trajectory["r"] has shape
        # (T, 1).
        self.r_running_stats = util.RunningStats(shape=(1, ))

        self._init_network()

    def _init_network(self):
        """Call both policies once on a real observation.

        Presumably this makes Keras create their weights before
        ``update_policy`` copies them with get_weights/set_weights.
        """
        env = gym.make(self.env_id)
        try:
            state = np.atleast_2d(env.reset())
        finally:
            env.close()

        self.policy(state)
        self.old_policy(state)

    def run(self, n_updates, logdir):
        """Train for ``n_updates`` epochs, logging summaries to ``logdir``.

        After every epoch one test episode is played. Once ``epoch // 10 > 10``,
        the model is saved whenever the 10-evaluation moving average of test
        scores reaches a new high.

        Returns:
            dict: ``{"steps": [...], "scores": [...]}``, the global step count
            and mean test score recorded after each epoch.
        """
        self.summary_writer = tf.summary.create_file_writer(str(logdir))

        history = {"steps": [], "scores": []}

        # Current observations of the vectorized envs. Kept separate from the
        # training minibatch below so the next rollout resumes from here.
        states = self.vecenv.reset()

        hiscore = None

        for epoch in range(n_updates):

            for _ in range(self.trajectory_size):
                actions = self.policy.sample_action(states)
                states = self.vecenv.step(actions)

            trajectories = self.vecenv.get_trajectories()

            for trajectory in trajectories:
                self.r_running_stats.update(trajectory["r"])

            trajectories = self.compute_advantage(trajectories)

            (mb_states, mb_actions, mb_advantages,
             mb_vtargs) = self.create_minibatch(trajectories)

            vloss = self.update_critic(mb_states, mb_vtargs)

            self.update_policy(mb_states, mb_actions, mb_advantages)

            global_steps = (epoch + 1) * self.trajectory_size * self.n_envs
            train_scores = np.array([traj["r"].sum() for traj in trajectories])

            test_scores, total_steps = self.play(n=1)
            test_scores = np.array(test_scores)
            total_steps = np.array(total_steps)
            history["steps"].append(global_steps)
            history["scores"].append(test_scores.mean())
            ma_score = sum(history["scores"][-10:]) / 10

            with self.summary_writer.as_default():
                tf.summary.scalar("test_score",
                                  test_scores.mean(),
                                  step=epoch)
                tf.summary.scalar("test_steps",
                                  total_steps.mean(),
                                  step=epoch)
            print(
                f"Epoch {epoch}, {global_steps//1000}K, {test_scores.mean()}"
            )

            if epoch // 10 > 10 and (hiscore is None or ma_score > hiscore):
                self.save_model()
                hiscore = ma_score
                print("Model Saved")

            with self.summary_writer.as_default():
                tf.summary.scalar("value_loss", vloss, step=epoch)
                tf.summary.scalar("train_score",
                                  train_scores.mean(),
                                  step=epoch)

        return history

    def compute_advantage(self, trajectories):
        """Generalized Advantage Estimation (GAE, 2016).

        For each trajectory this adds:

        * ``"v_pred"`` and ``"v_pred_next"``: critic predictions.
        * ``"advantage"``: GAE computed on rewards divided by their running
          standard deviation.
        * ``"R"``: ``advantage + v_pred``, the critic's regression target.

        Returns:
            list: the same trajectory dicts, updated in place.
        """
        for trajectory in trajectories:

            trajectory["v_pred"] = self.critic(trajectory["s"]).numpy()
            trajectory["v_pred_next"] = self.critic(trajectory["s2"]).numpy()

            # Do not bootstrap (or propagate GAE) across episode boundaries.
            is_nonterminals = 1 - trajectory["done"]

            normed_rewards = (trajectory["r"] /
                              (np.sqrt(self.r_running_stats.var) + 1e-4))

            deltas = (normed_rewards
                      + self.GAMMA * is_nonterminals * trajectory["v_pred_next"]
                      - trajectory["v_pred"])

            advantages = np.zeros_like(deltas, dtype=np.float32)
            lastgae = 0
            for i in reversed(range(len(deltas))):
                lastgae = (deltas[i] + self.GAMMA * self.GAE_LAMBDA
                           * is_nonterminals[i] * lastgae)
                advantages[i] = lastgae

            trajectory["advantage"] = advantages
            trajectory["R"] = advantages + trajectory["v_pred"]

        return trajectories

    def update_policy(self, states, actions, advantages):
        """Run ``OPT_ITER`` clipped-surrogate gradient steps on the policy.

        The current policy is first copied into ``old_policy``. Each step
        uses ``BATCH_SIZE`` indices drawn with replacement, and gradients are
        clipped to a global norm of 0.5.
        """
        self.old_policy.set_weights(self.policy.get_weights())

        indices = np.random.choice(range(states.shape[0]),
                                   (self.OPT_ITER, self.BATCH_SIZE))

        for idx in indices:

            old_means, old_stdevs = self.old_policy(states[idx])
            old_logprob = self.compute_logprob(old_means, old_stdevs,
                                               actions[idx])

            with tf.GradientTape() as tape:

                new_means, new_stdevs = self.policy(states[idx])
                new_logprob = self.compute_logprob(new_means, new_stdevs,
                                                   actions[idx])

                ratio = tf.exp(new_logprob - old_logprob)
                ratio_clipped = tf.clip_by_value(ratio, 1 - self.CLIPRANGE,
                                                 1 + self.CLIPRANGE)

                loss_unclipped = ratio * advantages[idx]
                loss_clipped = ratio_clipped * advantages[idx]

                # Pessimistic (min) surrogate, negated for gradient descent.
                loss = -1 * tf.reduce_mean(
                    tf.minimum(loss_unclipped, loss_clipped))

            grads = tape.gradient(loss, self.policy.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 0.5)
            self.policy.optimizer.apply_gradients(
                zip(grads, self.policy.trainable_variables))

    def update_critic(self, states, v_targs):
        """Run ``OPT_ITER`` clipped value-loss gradient steps on the critic.

        New predictions are clipped to within ``CLIPRANGE`` of the critic's
        predictions from *before* this update, and the larger of the
        clipped and unclipped squared errors is minimized.

        Returns:
            float: mean loss over the ``OPT_ITER`` steps.
        """
        losses = []

        indices = np.random.choice(range(states.shape[0]),
                                   (self.OPT_ITER, self.BATCH_SIZE))

        # Reference predictions from the critic as it was before this update.
        # Recomputing them inside the loop would make vpred - old_vpred zero
        # at every step and turn the clipping into a no-op.
        old_vpreds = self.critic(states).numpy()

        for idx in indices:

            old_vpred = old_vpreds[idx]

            with tf.GradientTape() as tape:

                vpred = self.critic(states[idx])

                vpred_clipped = old_vpred + tf.clip_by_value(
                    vpred - old_vpred, -self.CLIPRANGE, self.CLIPRANGE)

                loss = tf.maximum(tf.square(v_targs[idx] - vpred),
                                  tf.square(v_targs[idx] - vpred_clipped))
                loss = tf.reduce_mean(loss)

            grads = tape.gradient(loss, self.critic.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 0.5)
            self.critic.optimizer.apply_gradients(
                zip(grads, self.critic.trainable_variables))

            losses.append(loss)

        return np.array(losses).mean()

    @tf.function
    def compute_logprob(self, means, stdevs, actions):
        """Compute the log-density of a diagonal Gaussian.

        Per dimension:
            log p(x) = -0.5 log(2π) - log(std) - 0.5 * ((x - mean) / std)^2

        The result is summed over axis 1 (the action dimensions), keeping
        that axis.
        """
        logprob = -0.5 * np.log(2 * np.pi)
        logprob += -tf.math.log(stdevs)
        logprob += -0.5 * tf.square((actions - means) / stdevs)
        logprob = tf.reduce_sum(logprob, axis=1, keepdims=True)
        return logprob

    def create_minibatch(self, trajectories):
        """Stack per-trajectory arrays row-wise into training arrays.

        Returns:
            tuple: ``(states, actions, advantages, v_targs)``.
        """
        states = np.vstack([traj["s"] for traj in trajectories])
        actions = np.vstack([traj["a"] for traj in trajectories])
        advantages = np.vstack([traj["advantage"] for traj in trajectories])
        v_targs = np.vstack([traj["R"] for traj in trajectories])

        return states, actions, advantages, v_targs

    def save_model(self):
        """Save policy and critic weights under ``checkpoints/``."""
        self.policy.save_weights("checkpoints/policy")
        self.critic.save_weights("checkpoints/critic")

    def load_model(self):
        """Load policy and critic weights from ``checkpoints/``."""
        self.policy.load_weights("checkpoints/policy")
        self.critic.load_weights("checkpoints/critic")

    def play(self, n=1, monitordir=None, verbose=False):
        """Play ``n`` episodes with the current policy.

        Args:
            n (int): number of episodes.
            monitordir: if truthy, record every episode with
                ``wrappers.Monitor`` into this directory.
            verbose (bool): print the policy mean/std and the reward each step.

        Returns:
            tuple: ``(total_rewards, total_steps)``, two lists with one entry
            per episode.
        """
        if monitordir:
            env = wrappers.Monitor(gym.make(self.env_id),
                                   monitordir,
                                   force=True,
                                   video_callable=(lambda ep: True))
        else:
            env = gym.make(self.env_id)

        total_rewards = []
        total_steps = []

        try:
            for _ in range(n):

                state = env.reset()
                done = False
                total_reward = 0
                steps = 0

                while not done:

                    steps += 1

                    action = self.policy.sample_action(state)
                    next_state, reward, done, _ = env.step(action[0])

                    if verbose:
                        mean, sd = self.policy(np.atleast_2d(state))
                        print(mean, sd)
                        print(reward)

                    total_reward += reward

                    if done:
                        break
                    state = next_state

                total_rewards.append(total_reward)
                total_steps.append(steps)
                print()
                print(total_reward, steps)
                print()
        finally:
            # Release the env. For a Monitor wrapper, this also finalizes
            # the recording output.
            env.close()

        return total_rewards, total_steps