def main(num_cpus,
         n_episodes=10000,
         buffer_size=40000,
         batch_size=64,
         epochs_per_update=5,
         num_mcts_simulations=50,
         update_period=300,
         test_period=300,
         n_testplay=20,
         save_period=3000,
         dirichlet_alpha=0.35):

    ray.init(num_cpus=num_cpus, num_gpus=1, local_mode=False)

    logdir = Path(__file__).parent / "log"
    if logdir.exists():
        shutil.rmtree(logdir)
    summary_writer = tf.summary.create_file_writer(str(logdir))

    network = AlphaZeroResNet(action_space=othello.ACTION_SPACE)

    #: initialize network parameters
    dummy_state = othello.encode_state(othello.get_initial_state(), 1)

    network.predict(dummy_state)

    current_weights = ray.put(network.get_weights())

    #optimizer = tf.keras.optimizers.SGD(lr=lr, momentum=0.9)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

    replay = ReplayBuffer(buffer_size=buffer_size)

    #: Parallel self-play workers
    work_in_progresses = [
        selfplay.remote(current_weights, num_mcts_simulations, dirichlet_alpha)
        for _ in range(num_cpus - 2)
    ]

    test_in_progress = testplay.remote(current_weights,
                                       num_mcts_simulations,
                                       n_testplay=n_testplay)

    n_updates = 0
    n = 0
    while n <= n_episodes:

        for _ in tqdm(range(update_period)):
            #: Retrieve one self-play process that has finished
            finished, work_in_progresses = ray.wait(work_in_progresses,
                                                    num_returns=1)
            replay.add_record(ray.get(finished[0]))
            work_in_progresses.extend([
                selfplay.remote(current_weights, num_mcts_simulations,
                                dirichlet_alpha)
            ])
            n += 1

        #: Update network
        if len(replay) >= 20000:
            #if len(replay) >= 2000:

            num_iters = epochs_per_update * (len(replay) // batch_size)
            for i in range(num_iters):

                states, mcts_policy, rewards = replay.get_minibatch(
                    batch_size=batch_size)

                with tf.GradientTape() as tape:

                    p_pred, v_pred = network(states, training=True)
                    value_loss = tf.square(rewards - v_pred)

                    policy_loss = -mcts_policy * tf.math.log(p_pred + 0.0001)
                    policy_loss = tf.reduce_sum(policy_loss,
                                                axis=1,
                                                keepdims=True)

                    loss = tf.reduce_mean(value_loss + policy_loss)

                grads = tape.gradient(loss, network.trainable_variables)
                optimizer.apply_gradients(
                    zip(grads, network.trainable_variables))

                n_updates += 1

                if i % 100 == 0:
                    with summary_writer.as_default():
                        tf.summary.scalar("v_loss",
                                          value_loss.numpy().mean(),
                                          step=n_updates)
                        tf.summary.scalar("p_loss",
                                          policy_loss.numpy().mean(),
                                          step=n_updates)

            current_weights = ray.put(network.get_weights())

        if n % test_period == 0:
            print(f"{n - test_period}: TEST")
            win_count, win_ratio, elapsed_time = ray.get(test_in_progress)
            print(f"SCORE: {win_count}, {win_ratio}, Elapsed: {elapsed_time}")
            test_in_progress = testplay.remote(current_weights,
                                               num_mcts_simulations,
                                               n_testplay=n_testplay)

            with summary_writer.as_default():
                tf.summary.scalar("win_count", win_count, step=n - test_period)
                tf.summary.scalar("win_ratio", win_ratio, step=n - test_period)
                tf.summary.scalar("buffer_size", len(replay), step=n)

        if n % save_period == 0:
            network.save_weights("checkpoints/network")
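
#: A minimal standalone sketch (illustration only, not part of the training
#: script above) of the AlphaZero loss used in the update step: cross-entropy
#: between the MCTS visit distribution and the predicted policy, plus a
#: mean-squared error on the final game outcome. Shapes are assumed to be
#: p_pred / mcts_policy: [B, action_space], v_pred / rewards: [B, 1].
import tensorflow as tf


def alphazero_loss(p_pred, v_pred, mcts_policy, rewards, eps=1e-4):
    #: value head: squared error against the game outcome in {-1, 0, 1}
    value_loss = tf.square(rewards - v_pred)
    #: policy head: cross-entropy against the MCTS visit-count distribution
    policy_loss = -tf.reduce_sum(mcts_policy * tf.math.log(p_pred + eps),
                                 axis=1, keepdims=True)
    return tf.reduce_mean(value_loss + policy_loss)
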
class DQNAgent:
    def __init__(self,
                 env_name="BreakoutDeterministic-v4",
                 gamma=0.99,
                 batch_size=32,
                 lr=0.00025,
                 update_period=4,
                 target_update_period=10000,
                 n_frames=4):

        self.env_name = env_name

        self.gamma = gamma

        self.batch_size = batch_size

        self.epsilon_scheduler = (
            lambda steps: max(1.0 - 0.9 * steps / 1000000, 0.1))

        self.update_period = update_period

        self.target_update_period = target_update_period

        env = gym.make(self.env_name)

        self.action_space = env.action_space.n

        self.qnet = QNetwork(self.action_space)

        self.target_qnet = QNetwork(self.action_space)

        self.optimizer = Adam(learning_rate=lr,
                              epsilon=0.01 / self.batch_size)

        self.n_frames = n_frames

        self.use_reward_clipping = True

        self.huber_loss = tf.keras.losses.Huber()

    def learn(self, n_episodes, buffer_size=1000000, logdir="log"):

        logdir = Path(__file__).parent / logdir
        if logdir.exists():
            shutil.rmtree(logdir)
        self.summary_writer = tf.summary.create_file_writer(str(logdir))

        self.replay_buffer = ReplayBuffer(max_len=buffer_size)

        steps = 0
        for episode in range(1, n_episodes + 1):
            env = gym.make(self.env_name)

            frame = preprocess_frame(env.reset())
            frames = collections.deque([frame] * self.n_frames,
                                       maxlen=self.n_frames)

            episode_rewards = 0
            episode_steps = 0
            done = False
            lives = 5

            while not done:

                steps, episode_steps = steps + 1, episode_steps + 1

                epsilon = self.epsilon_scheduler(steps)

                state = np.stack(frames, axis=2)[np.newaxis, ...]

                action = self.qnet.sample_action(state, epsilon=epsilon)

                next_frame, reward, done, info = env.step(action)

                episode_rewards += reward

                frames.append(preprocess_frame(next_frame))

                next_state = np.stack(frames, axis=2)[np.newaxis, ...]

                if info["ale.lives"] != lives:
                    lives = info["ale.lives"]
                    transition = (state, action, reward, next_state, True)
                else:
                    transition = (state, action, reward, next_state, done)

                self.replay_buffer.push(transition)

                if len(self.replay_buffer) > 50000:
                    if steps % self.update_period == 0:
                        loss = self.update_network()
                        with self.summary_writer.as_default():
                            tf.summary.scalar("loss", loss, step=steps)
                            tf.summary.scalar("epsilon", epsilon, step=steps)
                            tf.summary.scalar("buffer_size",
                                              len(self.replay_buffer),
                                              step=steps)
                            tf.summary.scalar("train_score",
                                              episode_rewards,
                                              step=steps)
                            tf.summary.scalar("train_steps",
                                              episode_steps,
                                              step=steps)

                    if steps % self.target_update_period == 0:
                        self.target_qnet.set_weights(self.qnet.get_weights())

                if done:
                    break

            print(
                f"Episode: {episode}, score: {episode_rewards}, steps: {episode_steps}"
            )
            if episode % 20 == 0:
                test_scores, test_steps = self.test_play(n_testplay=1)
                with self.summary_writer.as_default():
                    tf.summary.scalar("test_score", test_scores[0], step=steps)
                    tf.summary.scalar("test_step", test_steps[0], step=steps)

            if episode % 1000 == 0:
                self.qnet.save_weights("checkpoints/qnet")

    def update_network(self):

        #: Build a minibatch
        (states, actions, rewards, next_states,
         dones) = self.replay_buffer.get_minibatch(self.batch_size)

        if self.use_reward_clipping:
            rewards = np.clip(rewards, -1, 1)

        next_actions, next_qvalues = self.target_qnet.sample_actions(
            next_states)
        next_actions_onehot = tf.one_hot(next_actions, self.action_space)
        max_next_qvalues = tf.reduce_sum(next_qvalues * next_actions_onehot,
                                         axis=1,
                                         keepdims=True)

        target_q = rewards + self.gamma * (1 - dones) * max_next_qvalues

        with tf.GradientTape() as tape:

            qvalues = self.qnet(states)
            actions_onehot = tf.one_hot(actions.flatten().astype(np.int32),
                                        self.action_space)
            q = tf.reduce_sum(qvalues * actions_onehot, axis=1, keepdims=True)
            loss = self.huber_loss(target_q, q)

        grads = tape.gradient(loss, self.qnet.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.qnet.trainable_variables))

        return loss

    def test_play(self, n_testplay=1, monitor_dir=None, checkpoint_path=None):

        if checkpoint_path:
            env = gym.make(self.env_name)
            frame = preprocess_frame(env.reset())
            frames = collections.deque([frame] * self.n_frames,
                                       maxlen=self.n_frames)

            state = np.stack(frames, axis=2)[np.newaxis, ...]
            self.qnet(state)
            self.qnet.load_weights(checkpoint_path)

        if monitor_dir:
            monitor_dir = Path(monitor_dir)
            if monitor_dir.exists():
                shutil.rmtree(monitor_dir)
            monitor_dir.mkdir()
            env = gym.wrappers.Monitor(gym.make(self.env_name),
                                       monitor_dir,
                                       force=True,
                                       video_callable=(lambda ep: True))
        else:
            env = gym.make(self.env_name)

        scores = []
        steps = []
        for _ in range(n_testplay):

            frame = preprocess_frame(env.reset())
            frames = collections.deque([frame] * self.n_frames,
                                       maxlen=self.n_frames)

            done = False
            episode_steps = 0
            episode_rewards = 0

            while not done:
                state = np.stack(frames, axis=2)[np.newaxis, ...]
                action = self.qnet.sample_action(state, epsilon=0.05)
                next_frame, reward, done, _ = env.step(action)
                frames.append(preprocess_frame(next_frame))

                episode_rewards += reward
                episode_steps += 1
                if episode_steps > 500 and episode_rewards < 3:
                    #: Workaround for runs that stall without ever starting the game (action: 0)
                    break

            scores.append(episode_rewards)
            steps.append(episode_steps)

        return scores, steps
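
#: A minimal sketch of a uniform replay buffer compatible with the calls made
#: by DQNAgent above (push / len / get_minibatch). The actual ReplayBuffer
#: module lives elsewhere in the repository; this hypothetical
#: SimpleReplayBuffer only illustrates the expected interface.
import collections

import numpy as np


class SimpleReplayBuffer:

    def __init__(self, max_len):
        self.buffer = collections.deque(maxlen=max_len)

    def __len__(self):
        return len(self.buffer)

    def push(self, transition):
        #: transition == (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def get_minibatch(self, batch_size):
        #: uniform sampling with replacement
        indices = np.random.randint(0, len(self.buffer), size=batch_size)
        samples = [self.buffer[i] for i in indices]
        states = np.vstack([s[0] for s in samples]).astype(np.float32)
        actions = np.array([s[1] for s in samples]).reshape(-1, 1)
        rewards = np.array([s[2] for s in samples],
                           dtype=np.float32).reshape(-1, 1)
        next_states = np.vstack([s[3] for s in samples]).astype(np.float32)
        dones = np.array([s[4] for s in samples],
                         dtype=np.float32).reshape(-1, 1)
        return states, actions, rewards, next_states, dones
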
class TD3Agent:

    MAX_EXPERIENCES = 30000

    MIN_EXPERIENCES = 300

    ENV_ID = "Pendulum-v0"

    ACTION_SPACE = 1

    MAX_ACTION = 2

    OBSERVATION_SPACE = 3

    CRITIC_UPDATE_PERIOD = 4

    POLICY_UPDATE_PERIOD = 8

    TAU = 0.02

    GAMMA = 0.99

    BATCH_SIZE = 64

    NOISE_STDDEV = 0.2

    def __init__(self):

        self.env = gym.make(self.ENV_ID)

        self.env.max_episode_steps = 3000

        self.actor = ActorNetwork(action_space=self.ACTION_SPACE,
                                  max_action=self.MAX_ACTION)

        self.target_actor = ActorNetwork(action_space=self.ACTION_SPACE,
                                         max_action=self.MAX_ACTION)

        self.critic = CriticNetwork()

        self.target_critic = CriticNetwork()

        self.buffer = ReplayBuffer(max_experiences=self.MAX_EXPERIENCES)

        self.global_steps = 0

        self.hiscore = None

        self._build_networks()

    def _build_networks(self):
        """パラメータの初期化
        """

        dummy_state = np.random.normal(0, 0.1, size=self.OBSERVATION_SPACE)
        dummy_state = (dummy_state[np.newaxis, ...]).astype(np.float32)

        dummy_action = np.random.normal(0, 0.1, size=self.ACTION_SPACE)
        dummy_action = (dummy_action[np.newaxis, ...]).astype(np.float32)

        self.actor.call(dummy_state)
        self.target_actor.call(dummy_state)
        self.target_actor.set_weights(self.actor.get_weights())

        self.critic.call(dummy_state, dummy_action, training=False)
        self.target_critic.call(dummy_state, dummy_action, training=False)
        self.target_critic.set_weights(self.critic.get_weights())

    def play(self, n_episodes):

        total_rewards = []

        recent_scores = collections.deque(maxlen=10)

        for n in range(n_episodes):

            total_reward, localsteps = self.play_episode()

            total_rewards.append(total_reward)

            recent_scores.append(total_reward)

            recent_average_score = sum(recent_scores) / len(recent_scores)

            print(f"Episode {n}: {total_reward}")
            print(f"Local steps {localsteps}")
            print(f"Experiences {len(self.buffer)}")
            print(f"Global step {self.global_steps}")
            print(f"Noise stdev {self.NOISE_STDDEV}")
            print(f"recent average score {recent_average_score}")
            print()

            if (self.hiscore is None) or (recent_average_score > self.hiscore):
                self.hiscore = recent_average_score
                print(f"HISCORE Updated: {self.hiscore}")
                self.save_model()

        return total_rewards

    def play_episode(self):

        total_reward = 0

        steps = 0

        done = False

        state = self.env.reset()

        while not done:

            action = self.actor.sample_action(state, noise=self.NOISE_STDDEV)

            next_state, reward, done, _ = self.env.step(action)

            exp = Experience(state, action, reward, next_state, done)

            self.buffer.add_experience(exp)

            state = next_state

            total_reward += reward

            steps += 1

            self.global_steps += 1

            #: Delayed Policy update
            if self.global_steps % self.CRITIC_UPDATE_PERIOD == 0:
                if self.global_steps % self.POLICY_UPDATE_PERIOD == 0:
                    self.update_network(self.BATCH_SIZE, update_policy=True)
                    self.update_target_network()
                else:
                    self.update_network(self.BATCH_SIZE)

        return total_reward, steps

    def update_network(self, batch_size, update_policy=False):

        if len(self.buffer) < self.MIN_EXPERIENCES:
            return

        (states, actions, rewards, next_states,
         dones) = self.buffer.get_minibatch(batch_size)

        clipped_noise = np.clip(np.random.normal(0, 0.2, self.ACTION_SPACE),
                                -0.5, 0.5)

        next_actions = self.target_actor(
            next_states) + clipped_noise * self.MAX_ACTION

        q1, q2 = self.target_critic(next_states, next_actions)

        next_qvalues = [
            min(q1, q2) for q1, q2 in zip(q1.numpy().flatten(),
                                          q2.numpy().flatten())
        ]

        #: Compute target values and update the critic network
        target_values = np.vstack([
            reward + self.GAMMA * next_qvalue if not done else reward
            for reward, done, next_qvalue in zip(rewards, dones, next_qvalues)
        ]).astype(np.float32)

        #: Update Critic
        with tf.GradientTape() as tape:
            q1, q2 = self.critic(states, actions)
            loss1 = tf.reduce_mean(tf.square(target_values - q1))
            loss2 = tf.reduce_mean(tf.square(target_values - q2))
            loss = loss1 + loss2

        variables = self.critic.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.critic.optimizer.apply_gradients(zip(gradients, variables))

        #: Delayed Update ActorNetwork
        if update_policy:

            with tf.GradientTape() as tape:
                q1, _ = self.critic(states, self.actor(states))
                J = -1 * tf.reduce_mean(q1)

            variables = self.actor.trainable_variables
            gradients = tape.gradient(J, variables)
            self.actor.optimizer.apply_gradients(zip(gradients, variables))

    def update_target_network(self):

        # soft-target update Actor
        target_actor_weights = self.target_actor.get_weights()
        actor_weights = self.actor.get_weights()

        assert len(target_actor_weights) == len(actor_weights)

        self.target_actor.set_weights([
            (1 - self.TAU) * target_w + self.TAU * w
            for target_w, w in zip(target_actor_weights, actor_weights)
        ])

        # soft-target update Critic
        target_critic_weights = self.target_critic.get_weights()
        critic_weights = self.critic.get_weights()

        assert len(target_critic_weights) == len(critic_weights)

        self.target_critic.set_weights([
            (1 - self.TAU) * target_w + self.TAU * w
            for target_w, w in zip(target_critic_weights, critic_weights)
        ])

    def save_model(self):

        self.actor.save_weights("checkpoints/actor")

        self.critic.save_weights("checkpoints/critic")

    def load_model(self):

        self.actor.load_weights("checkpoints/actor")

        self.target_actor.load_weights("checkpoints/actor")

        self.critic.load_weights("checkpoints/critic")

        self.target_critic.load_weights("checkpoints/critic")

    def test_play(self, n, monitordir, load_model=False):

        if load_model:
            self.load_model()

        if monitordir:
            env = wrappers.Monitor(gym.make(self.ENV_ID),
                                   monitordir,
                                   force=True,
                                   video_callable=(lambda ep: ep % 1 == 0))
        else:
            env = gym.make(self.ENV_ID)

        for i in range(n):

            total_reward = 0

            steps = 0

            done = False

            state = env.reset()

            while not done:

                action = self.actor.sample_action(state, noise=False)

                next_state, reward, done, _ = env.step(action)

                state = next_state

                total_reward += reward

                steps += 1

            print()
            print(f"Test Play {i}: {total_reward}")
            print(f"Steps:", steps)
            print()
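
#: A minimal standalone numpy sketch (illustration only; the helper names are
#: hypothetical) of the two TD3 ingredients used in update_network above:
#: target policy smoothing with clipped Gaussian noise, and the clipped
#: double-Q target that takes the minimum of the two target critics.
import numpy as np


def smoothed_target_action(mu, max_action=2.0, stddev=0.2, noise_clip=0.5):
    #: mu: deterministic target-policy action, shape [B, action_dim]
    noise = np.clip(np.random.normal(0, stddev, size=mu.shape),
                    -noise_clip, noise_clip)
    #: clip the perturbed action back to the valid action range
    return np.clip(mu + noise * max_action, -max_action, max_action)


def clipped_double_q_target(rewards, dones, q1, q2, gamma=0.99):
    #: rewards, dones, q1, q2: arrays of shape [B]
    next_q = np.minimum(q1, q2)
    return rewards + gamma * (1.0 - dones) * next_q
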
class DDPGAgent:

    MAX_EXPERIENCES = 30000

    MIN_EXPERIENCES = 300

    ENV_ID = "Pendulum-v0"

    ACTION_SPACE = 1

    OBSERVATION_SPACE = 3

    UPDATE_PERIOD = 4

    START_EPISODES = 20

    TAU = 0.02

    GAMMA = 0.99

    BATCH_SIZE = 32

    def __init__(self):

        self.env = gym.make(self.ENV_ID)

        self.env.max_episode_steps = 1000

        self.actor_network = ActorNetwork(action_space=self.ACTION_SPACE)

        self.target_actor_network = ActorNetwork(
            action_space=self.ACTION_SPACE)

        self.critic_network = CriticNetwork()

        self.target_critic_network = CriticNetwork()

        self.stdev = 0.2

        self.buffer = ReplayBuffer(max_experiences=self.MAX_EXPERIENCES)

        self.global_steps = 0

        self.hiscore = None

        self._build_networks()

    def _build_networks(self):
        """パラメータの初期化
        """

        dummy_state = np.random.normal(0, 0.1, size=self.OBSERVATION_SPACE)
        dummy_state = (dummy_state[np.newaxis, ...]).astype(np.float32)

        dummy_action = np.random.normal(0, 0.1, size=self.ACTION_SPACE)
        dummy_action = (dummy_action[np.newaxis, ...]).astype(np.float32)

        self.actor_network.call(dummy_state)
        self.target_actor_network.call(dummy_state)
        self.target_actor_network.set_weights(self.actor_network.get_weights())

        self.critic_network.call(dummy_state, dummy_action, training=False)
        self.target_critic_network.call(dummy_state,
                                        dummy_action,
                                        training=False)
        self.target_critic_network.set_weights(
            self.critic_network.get_weights())

    def play(self, n_episodes):

        total_rewards = []

        recent_scores = collections.deque(maxlen=10)

        for n in range(n_episodes):

            if n <= self.START_EPISODES:
                total_reward, localsteps = self.play_episode(random=True)
            else:
                total_reward, localsteps = self.play_episode()

            total_rewards.append(total_reward)

            recent_scores.append(total_reward)

            recent_average_score = sum(recent_scores) / len(recent_scores)

            print(f"Episode {n}: {total_reward}")
            print(f"Local steps {localsteps}")
            print(f"Experiences {len(self.buffer)}")
            print(f"Global step {self.global_steps}")
            print(f"Noise stdev {self.stdev}")
            print(f"recent average score {recent_average_score}")
            print()

            if (self.hiscore is None) or (recent_average_score > self.hiscore):
                self.hiscore = recent_average_score
                print(f"HISCORE Updated: {self.hiscore}")
                self.save_model()

        return total_rewards

    def play_episode(self, random=False):

        total_reward = 0

        steps = 0

        done = False

        state = self.env.reset()

        while not done:

            if random:
                action = np.random.uniform(-2, 2, size=self.ACTION_SPACE)
            else:
                action = self.actor_network.sample_action(state,
                                                          noise=self.stdev)

            next_state, reward, done, _ = self.env.step(action)

            exp = Experience(state, action, reward, next_state, done)

            self.buffer.add_experience(exp)

            state = next_state

            total_reward += reward

            steps += 1

            self.global_steps += 1

            if self.global_steps % self.UPDATE_PERIOD == 0:
                self.update_network(self.BATCH_SIZE)
                self.update_target_network()

        return total_reward, steps

    def update_network(self, batch_size):

        if len(self.buffer) < self.MIN_EXPERIENCES:
            return

        (states, actions, rewards, next_states,
         dones) = self.buffer.get_minibatch(batch_size)

        next_actions = self.target_actor_network(next_states)

        next_qvalues = self.target_critic_network(
            next_states, next_actions).numpy().flatten()

        #: Compute target values and update the critic network
        target_values = np.vstack([
            reward + self.GAMMA * next_qvalue if not done else reward
            for reward, done, next_qvalue in zip(rewards, dones, next_qvalues)
        ]).astype(np.float32)

        with tf.GradientTape() as tape:
            qvalues = self.critic_network(states, actions)
            loss = tf.reduce_mean(tf.square(target_values - qvalues))

        variables = self.critic_network.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.critic_network.optimizer.apply_gradients(zip(
            gradients, variables))

        #: Update ActorNetwork
        with tf.GradientTape() as tape:
            J = -1 * tf.reduce_mean(
                self.critic_network(states, self.actor_network(states)))

        variables = self.actor_network.trainable_variables
        gradients = tape.gradient(J, variables)
        self.actor_network.optimizer.apply_gradients(zip(gradients, variables))

    def update_target_network(self):

        # soft-target update Actor
        target_actor_weights = self.target_actor_network.get_weights()
        actor_weights = self.actor_network.get_weights()

        assert len(target_actor_weights) == len(actor_weights)

        self.target_actor_network.set_weights([
            (1 - self.TAU) * target_w + self.TAU * w
            for target_w, w in zip(target_actor_weights, actor_weights)
        ])

        # soft-target update Critic
        target_critic_weights = self.target_critic_network.get_weights()
        critic_weights = self.critic_network.get_weights()

        assert len(target_critic_weights) == len(critic_weights)

        self.target_critic_network.set_weights([
            (1 - self.TAU) * target_w + self.TAU * w
            for target_w, w in zip(target_critic_weights, critic_weights)
        ])

    def save_model(self):

        self.actor_network.save_weights("checkpoints/actor")

        self.critic_network.save_weights("checkpoints/critic")

    def load_model(self):

        self.actor_network.load_weights("checkpoints/actor")

        self.target_actor_network.load_weights("checkpoints/actor")

        self.critic_network.load_weights("checkpoints/critic")

        self.target_critic_network.load_weights("checkpoints/critic")

    def test_play(self, n, monitordir, load_model=False):

        if load_model:
            self.load_model()

        if monitordir:
            env = wrappers.Monitor(gym.make(self.ENV_ID),
                                   monitordir,
                                   force=True,
                                   video_callable=(lambda ep: ep % 1 == 0))
        else:
            env = gym.make(self.ENV_ID)

        for i in range(n):

            total_reward = 0

            steps = 0

            done = False

            state = env.reset()

            while not done:

                action = self.actor_network.sample_action(state, noise=False)

                next_state, reward, done, _ = env.step(action)

                state = next_state

                total_reward += reward

                steps += 1

            print()
            print(f"Test Play {i}: {total_reward}")
            print(f"Steps:", steps)
            print()
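
#: A usage sketch for the agent above. It assumes the ActorNetwork,
#: CriticNetwork, ReplayBuffer and Experience modules referenced by DDPGAgent
#: are importable; the episode count and monitor directory are arbitrary.
if __name__ == "__main__":
    agent = DDPGAgent()
    history = agent.play(n_episodes=150)
    agent.test_play(n=3, monitordir="video_ddpg", load_model=True)
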
class MPOAgent:
    def __init__(self, env_id: str, logdir: Path):

        self.env_id = env_id

        self.summary_writer = tf.summary.create_file_writer(
            str(logdir)) if logdir else None

        self.action_space = gym.make(self.env_id).action_space.shape[0]

        self.replay_buffer = ReplayBuffer(maxlen=10000)

        self.policy = GaussianPolicyNetwork(action_space=self.action_space)
        self.target_policy = GaussianPolicyNetwork(
            action_space=self.action_space)

        self.critic = QNetwork()
        self.target_critic = QNetwork()

        self.log_temperature = tf.Variable(1.)

        self.log_alpha_mu = tf.Variable(1.)
        self.log_alpha_sigma = tf.Variable(1.)

        self.eps = 0.1

        self.eps_mu = 0.01
        self.eps_sigma = 0.001

        self.policy_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
        self.temperature_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0005)
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

        self.batch_size = 128

        self.n_samples = 10

        self.update_period = 4

        self.gamma = 0.99

        self.target_policy_update_period = 400

        self.target_critic_update_period = 400

        self.global_steps = 0

        self.episode_count = 0

        self.setup()

    def setup(self):
        """ Initialize network weights """

        env = gym.make(self.env_id)

        dummy_state = env.reset()
        dummy_state = (dummy_state[np.newaxis, ...]).astype(np.float32)

        dummy_action = np.random.normal(0, 0.1, size=self.action_space)
        dummy_action = (dummy_action[np.newaxis, ...]).astype(np.float32)

        self.policy(dummy_state)
        self.target_policy(dummy_state)

        self.critic(dummy_state, dummy_action)
        self.target_critic(dummy_state, dummy_action)

        self.target_policy.set_weights(self.policy.get_weights())
        self.target_critic.set_weights(self.critic.get_weights())

    def save(self, save_dir):
        save_dir = Path(save_dir)

        self.policy.save_weights(str(save_dir / "policy"))
        self.critic.save_weights(str(save_dir / "critic"))

    def load(self, load_dir=None):
        load_dir = Path(load_dir)

        self.policy.load_weights(str(load_dir / "policy"))
        self.target_policy.load_weights(str(load_dir / "policy"))

        self.critic.load_weights(str(load_dir / "critic"))
        self.target_critic.load_weights(str(load_dir / "critic"))

    def rollout(self):

        episode_rewards, episode_steps = 0, 0

        done = False

        env = gym.make(self.env_id)

        state = env.reset()

        while not done:

            action = self.policy.sample_action(np.atleast_2d(state))

            action = action.numpy()[0]

            try:
                next_state, reward, done, _ = env.step(action)
            except Exception as err:
                print(err)
                import pdb
                pdb.set_trace()

            #: Clip the reward because BipedalWalker's fall penalty of -100 is too large
            transition = Transition(state, action, np.clip(reward, -1., 1.),
                                    next_state, done)

            self.replay_buffer.add(transition)

            state = next_state

            episode_rewards += reward

            episode_steps += 1

            self.global_steps += 1

            if (len(self.replay_buffer) >= 5000
                    and self.global_steps % self.update_period == 0):
                self.update_networks()

            if self.global_steps % self.target_critic_update_period == 0:
                self.target_critic.set_weights(self.critic.get_weights())

            if self.global_steps % self.target_policy_update_period == 0:
                self.target_policy.set_weights(self.policy.get_weights())

        self.episode_count += 1
        with self.summary_writer.as_default():
            tf.summary.scalar("episode_reward_stp",
                              episode_rewards,
                              step=self.global_steps)
            tf.summary.scalar("episode_steps_stp",
                              episode_steps,
                              step=self.global_steps)
            tf.summary.scalar("episode_reward",
                              episode_rewards,
                              step=self.episode_count)
            tf.summary.scalar("episode_steps",
                              episode_steps,
                              step=self.episode_count)

        return episode_rewards, episode_steps

    def update_networks(self):

        (states, actions, rewards, next_states,
         dones) = self.replay_buffer.get_minibatch(batch_size=self.batch_size)

        B, M = self.batch_size, self.n_samples

        # [B, obs_dim] -> [B, obs_dim * M] -> [B * M, obs_dim]
        next_states_tiled = tf.reshape(tf.tile(next_states, multiples=(1, M)),
                                       shape=(B * M, -1))

        target_mu, target_sigma = self.target_policy(next_states_tiled)

        # For MultivariateGaussianPolicy
        #target_dist = tfd.MultivariateNormalFullCovariance(loc=target_mu, covariance_matrix=target_sigma)

        # For IndependentGaussianPolicy
        target_dist = tfd.Independent(tfd.Normal(loc=target_mu,
                                                 scale=target_sigma),
                                      reinterpreted_batch_ndims=1)

        sampled_actions = target_dist.sample()  # [B * M,  action_dim]
        #sampled_actions = tf.clip_by_value(sampled_actions, -1.0, 1.0)

        # Update Q-network:
        sampled_qvalues = tf.reshape(self.target_critic(
            next_states_tiled, sampled_actions),
                                     shape=(B, M, -1))
        mean_qvalues = tf.reduce_mean(sampled_qvalues, axis=1)
        TQ = rewards + self.gamma * (1.0 - dones) * mean_qvalues

        with tf.GradientTape() as tape1:
            Q = self.critic(states, actions)
            loss_critic = tf.reduce_mean(tf.square(TQ - Q))

        variables = self.critic.trainable_variables
        grads = tape1.gradient(loss_critic, variables)
        grads, _ = tf.clip_by_global_norm(grads, 40.)
        self.critic_optimizer.apply_gradients(zip(grads, variables))

        # E-step:
        # Obtain η* by minimising g(η)
        with tf.GradientTape() as tape2:
            temperature = tf.math.softplus(self.log_temperature)
            q_logsumexp = tf.math.reduce_logsumexp(sampled_qvalues /
                                                   temperature,
                                                   axis=1)
            loss_temperature = temperature * (
                self.eps + tf.reduce_mean(q_logsumexp, axis=0))

        grad = tape2.gradient(loss_temperature, self.log_temperature)
        if tf.math.is_nan(grad).numpy().sum() != 0:
            print("NAN GRAD in TEMPERATURE !!!!!!!!!")
            import pdb
            pdb.set_trace()
        else:
            self.temperature_optimizer.apply_gradients([
                (grad, self.log_temperature)
            ])

        # Obtain sample-based variational distribution q(a|s)
        temperature = tf.math.softplus(self.log_temperature)

        # M-step: Optimize the lower bound J with respect to θ
        weights = tf.squeeze(tf.math.softmax(sampled_qvalues / temperature,
                                             axis=1),
                             axis=2)  # [B, M, 1] -> [B, M]

        if tf.math.is_nan(weights).numpy().sum() != 0:
            print("NAN in weights !!!!!!!!!")
            import pdb
            pdb.set_trace()

        with tf.GradientTape(persistent=True) as tape3:

            online_mu, online_sigma = self.policy(next_states_tiled)

            # For MultivariateGaussianPolicy
            #online_dist = tfd.MultivariateNormalFullCovariance(loc=online_mu, covariance_matrix=online_sigma)

            # For IndependentGaussianPolicy
            online_dist = tfd.Independent(tfd.Normal(loc=online_mu,
                                                     scale=online_sigma),
                                          reinterpreted_batch_ndims=1)

            log_probs = tf.reshape(online_dist.log_prob(sampled_actions) +
                                   1e-6,
                                   shape=(B, M))  # [B * M, ] -> [B, M]

            cross_entropy_qp = tf.reduce_sum(weights * log_probs,
                                             axis=1)  # [B, M] -> [B,]

            # For MultivariateGaussianPolicy
            # online_dist_fixedmu = tfd.MultivariateNormalFullCovariance(loc=target_mu, covariance_matrix=online_sigma)
            # online_dist_fixedsigma = tfd.MultivariateNormalFullCovariance(loc=online_mu, covariance_matrix=target_sigma)

            # For IndependentGaussianPolicy
            online_dist_fixedmu = tfd.Independent(tfd.Normal(
                loc=target_mu, scale=online_sigma),
                                                  reinterpreted_batch_ndims=1)
            online_dist_fixedsigma = tfd.Independent(
                tfd.Normal(loc=online_mu, scale=target_sigma),
                reinterpreted_batch_ndims=1)

            kl_mu = tf.reshape(
                target_dist.kl_divergence(online_dist_fixedsigma),
                shape=(B, M))  # [B * M, ] -> [B, M]

            kl_sigma = tf.reshape(
                target_dist.kl_divergence(online_dist_fixedmu),
                shape=(B, M))  # [B * M, ] -> [B, M]

            alpha_mu = tf.math.softplus(self.log_alpha_mu)
            alpha_sigma = tf.math.softplus(self.log_alpha_sigma)

            loss_policy = -cross_entropy_qp  # [B,]
            loss_policy += tf.stop_gradient(alpha_mu) * tf.reduce_mean(kl_mu,
                                                                       axis=1)
            loss_policy += tf.stop_gradient(alpha_sigma) * tf.reduce_mean(
                kl_sigma, axis=1)

            loss_policy = tf.reduce_mean(loss_policy)  # [B,] -> [1]

            loss_alpha_mu = tf.reduce_mean(
                alpha_mu *
                tf.stop_gradient(self.eps_mu - tf.reduce_mean(kl_mu, axis=1)))

            loss_alpha_sigma = tf.reduce_mean(
                alpha_sigma *
                tf.stop_gradient(self.eps_sigma -
                                 tf.reduce_mean(kl_sigma, axis=1)))

            loss_alpha = loss_alpha_mu + loss_alpha_sigma

        variables = self.policy.trainable_variables
        grads = tape3.gradient(loss_policy, variables)
        grads, _ = tf.clip_by_global_norm(grads, 40.)
        self.policy_optimizer.apply_gradients(zip(grads, variables))

        variables = [self.log_alpha_mu, self.log_alpha_sigma]
        grads = tape3.gradient(loss_alpha, variables)
        grads, _ = tf.clip_by_global_norm(grads, 40.)
        self.alpha_optimizer.apply_gradients(zip(grads, variables))

        del tape3

        with self.summary_writer.as_default():
            tf.summary.scalar("loss_policy",
                              loss_policy,
                              step=self.global_steps)
            tf.summary.scalar("loss_critic",
                              loss_critic,
                              step=self.global_steps)
            tf.summary.scalar("sigma",
                              tf.reduce_mean(online_sigma),
                              step=self.global_steps)
            tf.summary.scalar("kl_mu",
                              tf.reduce_mean(kl_mu),
                              step=self.global_steps)
            tf.summary.scalar("kl_sigma",
                              tf.reduce_mean(kl_sigma),
                              step=self.global_steps)
            tf.summary.scalar("temperature",
                              temperature,
                              step=self.global_steps)
            tf.summary.scalar("alpha_mu", alpha_mu, step=self.global_steps)
            tf.summary.scalar("alpha_sigma",
                              alpha_sigma,
                              step=self.global_steps)
            tf.summary.scalar("replay_buffer",
                              len(self.replay_buffer),
                              step=self.global_steps)

    def testplay(self, name, monitor_dir):

        total_rewards = []

        env = wrappers.RecordVideo(gym.make(self.env_id),
                                   video_folder=monitor_dir,
                                   step_trigger=lambda i: True,
                                   name_prefix=name)

        state = env.reset()

        done = False

        total_reward = 0

        while not done:

            action = self.policy.sample_action(np.atleast_2d(state))

            action = action.numpy()[0]

            next_state, reward, done, _ = env.step(action)

            total_reward += reward

            state = next_state

        total_rewards.append(total_reward)

        print(f"{name}", total_reward)
class SAC:

    MAX_EXPERIENCES = 100000

    MIN_EXPERIENCES = 512

    UPDATE_PERIOD = 4

    GAMMA = 0.99

    TAU = 0.005

    BATCH_SIZE = 256

    def __init__(self, env_id, action_space, action_bound):

        self.env_id = env_id

        self.action_space = action_space

        self.action_bound = action_bound

        self.env = gym.make(self.env_id)

        self.replay_buffer = ReplayBuffer(max_len=self.MAX_EXPERIENCES)

        self.policy = GaussianPolicy(action_space=self.action_space,
                                     action_bound=self.action_bound)

        self.duqlqnet = DualQNetwork()

        self.target_dualqnet = DualQNetwork()

        self.log_alpha = tf.Variable(0.)  #: alpha=1

        self.alpha_optimizer = tf.keras.optimizers.Adam(3e-4)

        self.target_entropy = -0.5 * self.action_space

        self.global_steps = 0

        self._initialize_weights()

    def _initialize_weights(self):
        """1度callすることでネットワークの重みを初期化
        """

        env = gym.make(self.env_id)

        dummy_state = env.reset()
        dummy_state = (dummy_state[np.newaxis, ...]).astype(np.float32)

        dummy_action = np.random.normal(0, 0.1, size=self.action_space)
        dummy_action = (dummy_action[np.newaxis, ...]).astype(np.float32)

        self.policy(dummy_state)

        self.duqlqnet(dummy_state, dummy_action)
        self.target_dualqnet(dummy_state, dummy_action)
        self.target_dualqnet.set_weights(self.duqlqnet.get_weights())

    def play_episode(self):

        episode_reward = 0

        local_steps = 0

        done = False

        state = self.env.reset()

        while not done:

            action, _ = self.policy.sample_action(np.atleast_2d(state))

            action = action.numpy()[0]

            next_state, reward, done, _ = self.env.step(action)

            exp = Experience(state, action, reward, next_state, done)

            self.replay_buffer.push(exp)

            state = next_state

            episode_reward += reward

            local_steps += 1

            self.global_steps += 1

            if (len(self.replay_buffer) >= self.MIN_EXPERIENCES
               and self.global_steps % self.UPDATE_PERIOD == 0):

                self.update_networks()

        return episode_reward, local_steps, tf.exp(self.log_alpha)

    def update_networks(self):

        (states, actions, rewards,
         next_states, dones) = self.replay_buffer.get_minibatch(self.BATCH_SIZE)

        alpha = tf.math.exp(self.log_alpha)

        #: Update Q-function
        next_actions, next_logprobs = self.policy.sample_action(next_states)

        target_q1, target_q2 = self.target_dualqnet(next_states, next_actions)

        target = rewards + (1 - dones) * self.GAMMA * (
            tf.minimum(target_q1, target_q2) - alpha * next_logprobs)

        with tf.GradientTape() as tape:
            q1, q2 = self.duqlqnet(states, actions)
            loss_1 = tf.reduce_mean(tf.square(target - q1))
            loss_2 = tf.reduce_mean(tf.square(target - q2))
            loss = 0.5 * loss_1 + 0.5 * loss_2

        variables = self.duqlqnet.trainable_variables
        grads = tape.gradient(loss, variables)
        self.duqlqnet.optimizer.apply_gradients(zip(grads, variables))

        #: Update policy
        with tf.GradientTape() as tape:
            selected_actions, logprobs = self.policy.sample_action(states)
            q1, q2 = self.duqlqnet(states, selected_actions)
            q_min = tf.minimum(q1, q2)
            loss = -1 * tf.reduce_mean(q_min - alpha * logprobs)

        variables = self.policy.trainable_variables
        grads = tape.gradient(loss, variables)
        self.policy.optimizer.apply_gradients(zip(grads, variables))

        #: Adjust alpha
        entropy_diff = -1 * logprobs - self.target_entropy
        with tf.GradientTape() as tape:
            tape.watch(self.log_alpha)
            selected_actions, logprobs = self.policy.sample_action(states)
            alpha_loss = tf.reduce_mean(tf.exp(self.log_alpha) * entropy_diff)

        grad = tape.gradient(alpha_loss, self.log_alpha)
        self.alpha_optimizer.apply_gradients([(grad, self.log_alpha)])

        #: Soft target update
        self.target_dualqnet.set_weights([
            (1 - self.TAU) * target_w + self.TAU * w
            for target_w, w in zip(self.target_dualqnet.get_weights(),
                                   self.duqlqnet.get_weights())
        ])

    def save_model(self):

        self.policy.save_weights("checkpoints/actor")

        self.duqlqnet.save_weights("checkpoints/critic")

    def load_model(self):

        self.policy.load_weights("checkpoints/actor")

        self.duqlqnet.load_weights("checkpoints/critic")

        self.target_dualqnet.load_weights("checkpoints/critic")

    def testplay(self, n=1, monitordir=None):

        if monitordir:
            env = wrappers.Monitor(gym.make(self.env_id),
                                   monitordir, force=True,
                                   video_callable=(lambda ep: True))
        else:
            env = gym.make(self.env_id)

        total_rewards = []

        for _ in range(n):

            state = env.reset()

            done = False

            total_reward = 0

            while not done:

                action, _ = self.policy.sample_action(np.atleast_2d(state))

                action = action.numpy()[0]

                next_state, reward, done, _ = env.step(action)

                total_reward += reward

                if done:
                    break
                else:
                    state = next_state

            total_rewards.append(total_reward)
            print()
            print(total_reward)
            print()

        return total_rewards
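
#: A usage sketch for the agent above. It assumes the GaussianPolicy,
#: DualQNetwork, ReplayBuffer and Experience modules referenced by SAC are
#: importable, and uses Pendulum-v0 (1-dimensional action bounded at +/-2).
if __name__ == "__main__":
    agent = SAC(env_id="Pendulum-v0", action_space=1, action_bound=2.0)
    for episode in range(100):
        episode_reward, steps, alpha = agent.play_episode()
        print(f"Episode {episode}: reward={episode_reward:.1f}, "
              f"steps={steps}, alpha={float(alpha):.3f}")
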
class FQFAgent:

    def __init__(self, env_name,
                 num_quantiles=32, fqf_factor=0.000001*0.1, ent_coef=0.001,
                 state_embedding_dim=3136, quantile_embedding_dim=64,
                 gamma=0.99, n_frames=4, batch_size=32,
                 buffer_size=1000000,
                 update_period=8,
                 target_update_period=10000):

        self.env_name = env_name

        self.num_quantiles = num_quantiles

        self.state_embedding_dim = state_embedding_dim

        self.quantile_embedding_dim = quantile_embedding_dim

        self.k = 1.0

        self.ent_coef = ent_coef

        self.n_frames = n_frames

        self.action_space = gym.make(self.env_name).action_space.n

        self.fqf_network = FQFNetwork(
            action_space=self.action_space,
            num_quantiles=self.num_quantiles,
            state_embedding_dim=self.state_embedding_dim,
            quantile_embedding_dim=self.quantile_embedding_dim)

        self.target_fqf_network = FQFNetwork(
            action_space=self.action_space,
            num_quantiles=self.num_quantiles,
            state_embedding_dim=self.state_embedding_dim,
            quantile_embedding_dim=self.quantile_embedding_dim)

        self._define_network()

        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.00015, epsilon=0.01/32)

        #: fpl; fraction proposal layer
        self.optimizer_fpl = tf.keras.optimizers.Adam(
            learning_rate=0.00005 * fqf_factor,
            epsilon=0.0003125)

        self.gamma = gamma

        self.replay_buffer = ReplayBuffer(max_len=buffer_size)

        self.batch_size = batch_size

        self.update_period = update_period

        self.target_update_period = target_update_period

        self.steps = 0

    def _define_network(self):
        """ initialize network weights
        """
        env = gym.make(self.env_name)
        frames = collections.deque(maxlen=4)
        frame = frame_preprocess(env.reset())
        for _ in range(self.n_frames):
            frames.append(frame)

        state = np.stack(frames, axis=2)[np.newaxis, ...]
        self.fqf_network(state)
        self.target_fqf_network(state)
        self.target_fqf_network.set_weights(self.fqf_network.get_weights())

    @property
    def epsilon(self):
        if self.steps <= 1000000:
            return max(0.99 * (1000000 - self.steps) / 1000000, 0.1)
        elif self.steps <= 2000000:
            return 0.05 + 0.05 * (2000000 - self.steps) / 2000000
        else:
            return 0.05

    def learn(self, n_episodes, logdir="log"):

        logdir = Path(__file__).parent / logdir
        if logdir.exists():
            shutil.rmtree(logdir)
        self.summary_writer = tf.summary.create_file_writer(str(logdir))

        for episode in range(1, n_episodes+1):

            env = gym.make(self.env_name)

            frames = collections.deque(maxlen=4)
            frame = frame_preprocess(env.reset())
            for _ in range(self.n_frames):
                frames.append(frame)

            episode_rewards = 0
            episode_steps = 0
            done = False
            lives = 5
            while not done:
                self.steps += 1
                episode_steps += 1
                state = np.stack(frames, axis=2)[np.newaxis, ...]
                action = self.fqf_network.sample_action(state, epsilon=self.epsilon)
                next_frame, reward, done, info = env.step(action)
                episode_rewards += reward
                frames.append(frame_preprocess(next_frame))
                next_state = np.stack(frames, axis=2)[np.newaxis, ...]

                if done:
                    exp = Experience(state, action, reward, next_state, done)
                    self.replay_buffer.push(exp)
                    break
                else:
                    if info["ale.lives"] != lives:
                        #: treat a life loss as the end of an episode
                        lives = info["ale.lives"]
                        exp = Experience(state, action, reward, next_state, True)
                    else:
                        exp = Experience(state, action, reward, next_state, done)

                    self.replay_buffer.push(exp)

                if (len(self.replay_buffer) > 50000) and (self.steps % self.update_period == 0):

                    loss, loss_fp, entropy = self.update_network()

                    with self.summary_writer.as_default():
                        tf.summary.scalar("loss", loss, step=self.steps)
                        tf.summary.scalar("loss_fp", loss_fp, step=self.steps)
                        tf.summary.scalar("entropy", entropy, step=self.steps)
                        tf.summary.scalar("epsilon", self.epsilon, step=self.steps)
                        tf.summary.scalar("buffer_size", len(self.replay_buffer), step=self.steps)
                        tf.summary.scalar("train_score", episode_rewards, step=self.steps)
                        tf.summary.scalar("train_steps", episode_steps, step=self.steps)

                #: Target update
                if self.steps % self.target_update_period == 0:
                    self.target_fqf_network.set_weights(
                        self.fqf_network.get_weights())

            print(f"Episode: {episode}, score: {episode_rewards}, steps: {episode_steps}")

            if episode % 20 == 0:
                test_scores, test_steps = self.test_play(n_testplay=1)
                with self.summary_writer.as_default():
                    tf.summary.scalar("test_score", test_scores[0], step=self.steps)
                    tf.summary.scalar("test_step", test_steps[0], step=self.steps)

            if episode % 500 == 0:
                self.fqf_network.save_weights("checkpoints/fqfnet")
                print("Model Saved")

    def update_network(self):

        (states, actions, rewards,
         next_states, dones) = self.replay_buffer.get_minibatch(self.batch_size)

        rewards = rewards.reshape((self.batch_size, 1, 1))
        dones = dones.reshape((self.batch_size, 1, 1))

        with tf.GradientTape() as tape:
            #: Compute F(τ^)
            state_embedded = self.fqf_network.state_embedding_layer(states)

            taus, taus_hat, taus_hat_probs = self.fqf_network.propose_fractions(state_embedded)
            taus_hat, taus_hat_probs = tf.stop_gradient(taus_hat), tf.stop_gradient(taus_hat_probs)

            quantiles = self.fqf_network.quantile_function(
                state_embedded, taus_hat)
            actions_onehot = tf.one_hot(
                actions.flatten().astype(np.int32), self.action_space)
            actions_mask = tf.expand_dims(actions_onehot, axis=2)
            quantiles = tf.reduce_sum(
                quantiles * actions_mask, axis=1, keepdims=True)

            #: Compute target F(τ^), use same taus proposed by online network
            next_actions, target_quantiles = self.target_fqf_network.greedy_action_on_given_taus(
                next_states, taus_hat, taus_hat_probs)

            next_actions_onehot = tf.one_hot(next_actions.numpy().flatten(), self.action_space)
            next_actions_mask = tf.expand_dims(next_actions_onehot, axis=2)
            target_quantiles = tf.reduce_sum(
                target_quantiles * next_actions_mask, axis=1, keepdims=True)

            #: TF(τ^)
            target_quantiles = rewards + self.gamma * (1-dones) * target_quantiles
            target_quantiles = tf.stop_gradient(target_quantiles)

            #: Compute Quantile regression loss
            target_quantiles = tf.repeat(
                target_quantiles, self.num_quantiles, axis=1)
            quantiles = tf.repeat(
                tf.transpose(quantiles, [0, 2, 1]), self.num_quantiles, axis=2)

            #: huberloss
            bellman_errors = target_quantiles - quantiles
            is_smaller_than_k = tf.abs(bellman_errors) < self.k
            squared_loss = 0.5 * tf.square(bellman_errors)
            linear_loss = self.k * (tf.abs(bellman_errors) - 0.5 * self.k)

            huberloss = tf.where(is_smaller_than_k, squared_loss, linear_loss)

            #: quantile loss
            indicator = tf.stop_gradient(tf.where(bellman_errors < 0, 1., 0.))
            _taus_hat = tf.repeat(
                tf.expand_dims(taus_hat, axis=2), self.num_quantiles, axis=2)

            quantile_factors = tf.abs(_taus_hat - indicator)
            quantile_huberloss = quantile_factors * huberloss

            loss = tf.reduce_mean(quantile_huberloss, axis=2)
            loss = tf.reduce_sum(loss, axis=1)
            loss = tf.reduce_mean(loss)

        state_embedding_vars = self.fqf_network.state_embedding_layer.trainable_variables
        quantile_function_vars = self.fqf_network.quantile_function.trainable_variables

        variables = state_embedding_vars + quantile_function_vars
        grads = tape.gradient(loss, variables)

        with tf.GradientTape() as tape2:
            taus_all = self.fqf_network.fraction_proposal_layer(state_embedded)
            taus = taus_all[:, 1:-1]

            quantiles = self.fqf_network.quantile_function(
                state_embedded, taus)
            taus_hat = (taus_all[:, 1:] + taus_all[:, :-1]) / 2.
            quantiles_hat = self.fqf_network.quantile_function(
                state_embedded, taus_hat)

            dw_dtau = 2 * quantiles - quantiles_hat[:, :, 1:] - quantiles_hat[:, :, :-1]
            dw_dtau = tf.reduce_sum(dw_dtau * actions_mask, axis=1)

            entropy = tf.reduce_sum(-1 * taus_hat * tf.math.log(taus_hat), axis=1)

            loss_fp = tf.reduce_mean(tf.square(dw_dtau), axis=1)
            loss_fp += -1 * self.ent_coef * entropy
            loss_fp = tf.reduce_mean(loss_fp)

        fp_variables = self.fqf_network.fraction_proposal_layer.trainable_variables
        grads_fp = tape2.gradient(loss_fp, fp_variables)

        self.optimizer.apply_gradients(zip(grads, variables))
        self.optimizer_fpl.apply_gradients(zip(grads_fp, fp_variables))

        return loss, loss_fp, tf.reduce_mean(entropy)

    def test_play(self, n_testplay=1, monitor_dir=None,
                  checkpoint_path=None):

        if checkpoint_path:
            env = gym.make(self.env_name)
            frames = collections.deque(maxlen=4)
            frame = frame_preprocess(env.reset())
            for _ in range(self.n_frames):
                frames.append(frame)
            state = np.stack(frames, axis=2)[np.newaxis, ...]
            self.fqf_network(state)
            self.fqf_network.load_weights(checkpoint_path)

        if monitor_dir:
            monitor_dir = Path(monitor_dir)
            if monitor_dir.exists():
                shutil.rmtree(monitor_dir)
            monitor_dir.mkdir()
            env = gym.wrappers.Monitor(
                gym.make(self.env_name), monitor_dir, force=True,
                video_callable=(lambda ep: True))
        else:
            env = gym.make(self.env_name)

        scores = []
        steps = []
        for _ in range(n_testplay):

            frames = collections.deque(maxlen=4)
            frame = frame_preprocess(env.reset())
            for _ in range(self.n_frames):
                frames.append(frame)

            done = False
            episode_steps = 0
            episode_rewards = 0

            while not done:
                state = np.stack(frames, axis=2)[np.newaxis, ...]
                action = self.fqf_network.sample_action(state, epsilon=0.01)
                next_frame, reward, done, _ = env.step(action)
                frames.append(frame_preprocess(next_frame))

                episode_rewards += reward
                episode_steps += 1
                if episode_steps > 500 and episode_rewards < 3:
                    #: Handle episodes that stall without the game being started (action: 0)
                    break

            scores.append(episode_rewards)
            steps.append(episode_steps)

        return scores, steps


class CategoricalDQNAgent:
    def __init__(self,
                 env_name="BreakoutDeterministic-v4",
                 n_atoms=51,
                 Vmin=-10,
                 Vmax=10,
                 gamma=0.98,
                 n_frames=4,
                 batch_size=32,
                 lr=0.00025,
                 init_epsilon=0.95,
                 update_period=8,
                 target_update_period=10000):

        self.env_name = env_name

        self.n_atoms = n_atoms

        self.Vmin, self.Vmax = Vmin, Vmax

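        #: Fixed support: n_atoms evenly spaced atoms on [Vmin, Vmax];
        #: delta_z is the gap between adjacent atoms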
        self.delta_z = (self.Vmax - self.Vmin) / (self.n_atoms - 1)

        self.Z = np.linspace(self.Vmin, self.Vmax, self.n_atoms)

        self.gamma = gamma

        self.n_frames = n_frames

        self.batch_size = batch_size

        self.init_epsilon = init_epsilon

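        #: Linear epsilon annealing: roughly 0.98 -> 0.1 over the first 500k
        #: steps, then 0.1 -> 0.05 over the next 500k steps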
        self.epsilon_scheduler = (
            lambda steps: max(0.98 * (500000 - steps) / 500000, 0.1)
            if steps < 500000 else max(
                0.05 + 0.05 * (1000000 - steps) / 500000, 0.05))

        self.update_period = update_period

        self.target_update_period = target_update_period

        env = gym.make(self.env_name)

        self.action_space = env.action_space.n

        self.qnet = CategoricalQNet(self.action_space, self.n_atoms, self.Z)

        self.target_qnet = CategoricalQNet(self.action_space, self.n_atoms,
                                           self.Z)

        self.optimizer = tf.keras.optimizers.Adam(lr=lr,
                                                  epsilon=0.01 / batch_size)

    def learn(self, n_episodes, buffer_size=800000, logdir="log"):

        logdir = Path(__file__).parent / logdir
        if logdir.exists():
            shutil.rmtree(logdir)
        self.summary_writer = tf.summary.create_file_writer(str(logdir))

        self.replay_buffer = ReplayBuffer(max_len=buffer_size)

        steps = 0
        for episode in range(1, n_episodes + 1):
            env = gym.make(self.env_name)

            frames = collections.deque(maxlen=4)
            frame = frame_preprocess(env.reset())
            for _ in range(self.n_frames):
                frames.append(frame)

            #: Build the networks and sync target weights (first episode only)
            if episode == 1:
                state = np.stack(frames, axis=2)[np.newaxis, ...]
                self.qnet(state)
                self.target_qnet(state)
                self.target_qnet.set_weights(self.qnet.get_weights())

            episode_rewards = 0
            episode_steps = 0

            done = False
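            #: Breakout starts with 5 lives; a life loss is pushed to the
            #: replay buffer as a terminal transition (see the while loop)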
            lives = 5
            while not done:

                steps += 1
                episode_steps += 1

                epsilon = self.epsilon_scheduler(steps)

                state = np.stack(frames, axis=2)[np.newaxis, ...]
                action = self.qnet.sample_action(state, epsilon=epsilon)
                next_frame, reward, done, info = env.step(action)
                episode_rewards += reward
                frames.append(frame_preprocess(next_frame))
                next_state = np.stack(frames, axis=2)[np.newaxis, ...]

                if done:
                    exp = Experience(state, action, reward, next_state, done)
                    self.replay_buffer.push(exp)
                    break
                else:
                    if info["ale.lives"] != lives:
                        lives = info["ale.lives"]
                        exp = Experience(state, action, reward, next_state,
                                         True)
                    else:
                        exp = Experience(state, action, reward, next_state,
                                         done)

                    self.replay_buffer.push(exp)

                if (len(self.replay_buffer) >
                        20000) and (steps % self.update_period == 0):
                    loss = self.update_network()

                    with self.summary_writer.as_default():
                        tf.summary.scalar("loss", loss, step=steps)
                        tf.summary.scalar("epsilon", epsilon, step=steps)
                        tf.summary.scalar("buffer_size",
                                          len(self.replay_buffer),
                                          step=steps)
                        tf.summary.scalar("train_score",
                                          episode_rewards,
                                          step=steps)
                        tf.summary.scalar("train_steps",
                                          episode_steps,
                                          step=steps)

                #: Hard target update
                if steps % self.target_update_period == 0:
                    self.target_qnet.set_weights(self.qnet.get_weights())

            print(
                f"Episode: {episode}, score: {episode_rewards}, steps: {episode_steps}"
            )

            if episode % 20 == 0:
                test_scores, test_steps = self.test_play(n_testplay=1)
                with self.summary_writer.as_default():
                    tf.summary.scalar("test_score", test_scores[0], step=steps)
                    tf.summary.scalar("test_step", test_steps[0], step=steps)

            if episode % 1000 == 0:
                print("Model Saved")
                self.qnet.save_weights("checkpoints/qnet")

    def update_network(self):

        #: Sample a minibatch from the replay buffer
        (states, actions, rewards, next_states,
         dones) = self.replay_buffer.get_minibatch(self.batch_size)

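        #: Next actions and their return distributions from the target network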
        next_actions, next_probs = self.target_qnet.sample_actions(next_states)

        #: Extract only the distribution of the selected action
        onehot_mask = self.create_mask(next_actions)
        next_dists = tf.reduce_sum(next_probs * onehot_mask, axis=1).numpy()

        #: Apply the distributional Bellman operator
        target_dists = self.shift_and_projection(rewards, dones, next_dists)

        onehot_mask = self.create_mask(actions)
        with tf.GradientTape() as tape:
            probs = self.qnet(states)

            dists = tf.reduce_sum(probs * onehot_mask, axis=1)
            #: Clip to avoid exploding gradients when taking the log
            dists = tf.clip_by_value(dists, 1e-6, 1.0)

            loss = tf.reduce_sum(-1 * target_dists * tf.math.log(dists),
                                 axis=1,
                                 keepdims=True)
            loss = tf.reduce_mean(loss)

        grads = tape.gradient(loss, self.qnet.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.qnet.trainable_variables))

        return loss

    def shift_and_projection(self, rewards, dones, next_dists):

        target_dists = np.zeros((self.batch_size, self.n_atoms))

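        """ 1. Distributional Bellman backup and projection
            Each atom z_j is shifted to tZ_j = clip(r + gamma * z_j, Vmin, Vmax)
            and its probability mass is split between the two neighbouring
            atoms of the fixed support.
            e.g. with Vmin=-10, Vmax=10, n_atoms=51 (delta_z=0.4), tZ_j=0.3
            gives bj=25.75, so 25% of the mass goes to atom 25 and 75% to atom 26.
        """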
        for j in range(self.n_atoms):

            tZ_j = np.minimum(
                self.Vmax,
                np.maximum(self.Vmin, rewards + self.gamma * self.Z[j]))
            bj = (tZ_j - self.Vmin) / self.delta_z

            lower_bj = np.floor(bj).astype(np.int32)
            upper_bj = np.ceil(bj).astype(np.int32)

            eq_mask = lower_bj == upper_bj
            neq_mask = lower_bj != upper_bj

            lower_probs = 1 - (bj - lower_bj)
            upper_probs = 1 - (upper_bj - bj)

            next_dist = next_dists[:, [j]]
            indices = np.arange(self.batch_size).reshape(-1, 1)

            target_dists[indices[neq_mask],
                         lower_bj[neq_mask]] += (lower_probs *
                                                 next_dist)[neq_mask]
            target_dists[indices[neq_mask],
                         upper_bj[neq_mask]] += (upper_probs *
                                                 next_dist)[neq_mask]

            target_dists[indices[eq_mask],
                         lower_bj[eq_mask]] += (0.5 * next_dist)[eq_mask]
            target_dists[indices[eq_mask],
                         upper_bj[eq_mask]] += (0.5 * next_dist)[eq_mask]
        """ 2. doneへの対処
            doneのときは TZ(t) = R(t)
        """
        for batch_idx in range(self.batch_size):

            if not dones[batch_idx]:
                continue
            else:
                target_dists[batch_idx, :] = 0
                tZ = np.minimum(self.Vmax,
                                np.maximum(self.Vmin, rewards[batch_idx]))
                bj = (tZ - self.Vmin) / self.delta_z

                lower_bj = np.floor(bj).astype(np.int32)
                upper_bj = np.ceil(bj).astype(np.int32)

                if lower_bj == upper_bj:
                    target_dists[batch_idx, lower_bj] += 1.0
                else:
                    target_dists[batch_idx, lower_bj] += 1 - (bj - lower_bj)
                    target_dists[batch_idx, upper_bj] += 1 - (upper_bj - bj)

        return target_dists

    def create_mask(self, actions):

        mask = np.ones((self.batch_size, self.action_space, self.n_atoms))
        actions_onehot = tf.one_hot(tf.cast(actions, tf.int32),
                                    self.action_space,
                                    axis=1)

        for idx in range(self.batch_size):
            mask[idx, ...] = mask[idx, ...] * actions_onehot[idx, ...]

        return mask

    def test_play(self, n_testplay=1, monitor_dir=None, checkpoint_path=None):

        if checkpoint_path:
            env = gym.make(self.env_name)
            frames = collections.deque(maxlen=4)
            frame = frame_preprocess(env.reset())
            for _ in range(self.n_frames):
                frames.append(frame)
            state = np.stack(frames, axis=2)[np.newaxis, ...]
            self.qnet(state)
            self.qnet.load_weights(checkpoint_path)

        if monitor_dir:
            monitor_dir = Path(monitor_dir)
            if monitor_dir.exists():
                shutil.rmtree(monitor_dir)
            monitor_dir.mkdir()
            env = gym.wrappers.Monitor(gym.make(self.env_name),
                                       monitor_dir,
                                       force=True,
                                       video_callable=(lambda ep: True))
        else:
            env = gym.make(self.env_name)

        scores = []
        steps = []
        for _ in range(n_testplay):

            frames = collections.deque(maxlen=4)

            frame = frame_preprocess(env.reset())
            for _ in range(self.n_frames):
                frames.append(frame)

            done = False
            episode_steps = 0
            episode_rewards = 0

            while not done:
                state = np.stack(frames, axis=2)[np.newaxis, ...]
                action = self.qnet.sample_action(state, epsilon=0.1)
                next_frame, reward, done, info = env.step(action)
                frames.append(frame_preprocess(next_frame))

                episode_rewards += reward
                episode_steps += 1
                if episode_steps > 500 and episode_rewards < 3:
                    #: Handle episodes that stall without the game being started (action: 0)
                    break

            scores.append(episode_rewards)
            steps.append(episode_steps)

        return scores, steps
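

if __name__ == "__main__":
    #: Illustrative usage sketch (not part of the original script): assumes
    #: this module is run directly; episode count, monitor_dir and n_testplay
    #: below are placeholder values.
    agent = CategoricalDQNAgent(env_name="BreakoutDeterministic-v4")
    agent.learn(n_episodes=5001)
    scores, steps = agent.test_play(n_testplay=3, monitor_dir="mp4")
    print(scores, steps)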