Example #1
import ray
import tensorflow as tf
from tensorflow.keras.metrics import Mean
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay

# MuZeroConfig, SharedStorage, ReplayBuffer, Network and create_summary are
# defined elsewhere in the surrounding MuZero project.


class Leaner:
    def __init__(self, config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
        self.config = config
        self.storage = storage
        self.replay_buffer = replay_buffer
        self.summary = create_summary(name="leaner")
        self.metrics_loss = Mean('leaner-loss', dtype=tf.float32)
        self.network = Network(self.config)
        self.lr_schedule = ExponentialDecay(
            initial_learning_rate=self.config.lr_init,
            decay_steps=self.config.lr_decay_steps,
            decay_rate=self.config.lr_decay_rate)
        self.optimizer = Adam(learning_rate=self.lr_schedule)

    def start(self):
        while self.network.training_steps() < self.config.training_steps:
            # Train only once the replay buffer holds at least one game.
            if ray.get(self.replay_buffer.size.remote()) > 0:

                self.train()

                # Periodically publish the latest weights to the shared storage
                # so the self-play actors can pick them up.
                if self.network.training_steps() % self.config.checkpoint_interval == 0:
                    weights = self.network.get_weights()
                    self.storage.update_network.remote(weights)

                # Periodically persist the network to disk.
                if self.network.training_steps() % self.config.save_interval == 0:
                    self.network.save()

        print("Finished")

    def train(self):
        # Sample a training batch from the remote replay buffer.
        batch = ray.get(self.replay_buffer.sample_batch.remote())

        # Compute the loss and apply one optimizer step.
        with tf.GradientTape() as tape:
            loss = self.network.loss_function(batch)

        grads = tape.gradient(loss, self.network.get_variables())
        self.optimizer.apply_gradients(zip(grads,
                                           self.network.get_variables()))

        # Log the averaged loss to TensorBoard, keyed by training step.
        self.metrics_loss(loss)
        with self.summary.as_default():
            tf.summary.scalar('loss', self.metrics_loss.result(),
                              step=self.network.training_steps())
        self.metrics_loss.reset_states()

        self.network.update_training_steps()
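
The optimizer's learning rate follows Keras's ExponentialDecay schedule, which evaluates to initial_learning_rate * decay_rate ** (step / decay_steps) at a given optimizer step. A minimal sketch of how the schedule behaves; the values below are placeholders for illustration, not the project's actual config:

from tensorflow.keras.optimizers.schedules import ExponentialDecay

# Placeholder values standing in for config.lr_init, config.lr_decay_steps
# and config.lr_decay_rate.
lr_schedule = ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=10_000,
    decay_rate=0.9)

for step in (0, 10_000, 50_000):
    # Effective rate: 1e-3 * 0.9 ** (step / 10_000)
    print(step, float(lr_schedule(step)))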
Example #2
import ray
import tensorflow as tf
from tensorflow.keras.metrics import Mean, Sum

# MuZeroConfig, SharedStorage, ReplayBuffer, Network, Game, Node, MinMaxStats,
# create_summary and the MCTS helpers (expand_node, backpropagate,
# add_exploration_noise, run_mcts, select_action) are defined elsewhere in the
# surrounding MuZero project.


class Actor:
    def __init__(self,
                 config: MuZeroConfig,
                 storage: SharedStorage,
                 replay_buffer: ReplayBuffer,
                 temperature: float = 1.0):
        self.config = config
        self.network = Network(self.config)
        self.storage = storage
        self.replay_buffer = replay_buffer
        self.temperature = temperature
        self.name = f"games-{temperature}"
        self.summary = create_summary(name=self.name)
        self.games_played = 0
        self.metrics_games = Sum(self.name, dtype=tf.int32)
        self.metrics_temperature = Sum(self.name, dtype=tf.float32)
        self.metrics_rewards = Mean(self.name, dtype=tf.float32)
        self.started = False

    def update_metrics(self):
        with self.summary.as_default():
            tf.summary.scalar('games-played', self.metrics_games.result(),
                              step=self.games_played)
            tf.summary.scalar('games-temperature',
                              self.metrics_temperature.result(),
                              step=self.games_played)
            tf.summary.scalar('games-rewards', self.metrics_rewards.result(),
                              step=self.games_played)

        self.metrics_temperature.reset_states()
        self.metrics_rewards.reset_states()

    def start(self):
        while self.games_played < self.config.training_steps:
            game = self.play_game()
            self.games_played += 1
            self.metrics_games(1)
            self.metrics_rewards(sum(game.rewards))
            self.update_metrics()

            self.replay_buffer.save_game.remote(game)

            # Do not pull weights until the learner has actually started training.
            if not self.started:
                self.started = ray.get(self.storage.started.remote())
                continue

            # Periodically refresh the local network from the shared storage.
            if self.games_played % self.config.checkpoint_interval == 0:
                weights = ray.get(self.storage.get_network_weights.remote())
                self.network.set_weights(weights)

        print(f"Actor: {self.name } finished.")

    def play_game(self) -> Game:
        game = Game(self.config.discount)
        min_max_stats = MinMaxStats(self.config.known_bounds)

        # Decay the temperature exponentially with the number of training
        # steps, never dropping below the configured minimum.
        temperature = max(
            self.temperature *
            (1 - self.config.temperature_decay_factor)**self.network.training_steps(),
            self.config.temperature_min)
        self.metrics_temperature(temperature)

        while not game.terminal() and len(game.history) < self.config.max_moves:

            # At the root of the search tree we use the representation function to
            # obtain a hidden state given the current observation.
            root = Node(0)
            current_observation = game.get_observation_from_index(-1)
            network_output = self.network.initial_inference(
                current_observation)
            expand_node(root, game.to_play(), game.legal_actions(),
                        network_output)
            backpropagate([root], network_output.value, game.to_play(),
                          self.config.discount, min_max_stats)
            add_exploration_noise(self.config, root)

            # We then run a Monte Carlo Tree Search using only action sequences and the
            # model learned by the network.
            run_mcts(self.config, root, game.action_history(), self.network,
                     min_max_stats)
            action = select_action(root, temperature)
            game.apply(action)
            game.store_search_statistics(root)

        return game
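
Both classes talk to remote handles (note the .remote() calls on storage and replay_buffer), so they are meant to run as Ray actors around a shared SharedStorage and ReplayBuffer. A minimal driver sketch under a few assumptions: the classes are turned into Ray actors with ray.remote, the SharedStorage and ReplayBuffer constructors accept the config, and the actor count and temperatures are placeholders:

import ray

ray.init()

config = MuZeroConfig()
# Assumed constructor signatures; adjust to the project's actual ones.
storage = ray.remote(SharedStorage).remote(config)
replay_buffer = ray.remote(ReplayBuffer).remote(config)

# One learner plus a few self-play actors with different temperatures.
leaner = ray.remote(Leaner).remote(config, storage, replay_buffer)
actors = [
    ray.remote(Actor).remote(config, storage, replay_buffer, temperature=t)
    for t in (1.0, 0.5, 0.25)
]

# Launch everything and block until training and self-play have finished.
ray.get([leaner.start.remote()] + [actor.start.remote() for actor in actors])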