Example #1
    def _expand(self, state, current_player):

        s = self.state_to_str(state, current_player)

        with tf.device("/cpu:0"):
            nn_policy, nn_value = self.network.predict(
                othello.encode_state(state, current_player))

        nn_policy = nn_policy.numpy().tolist()[0]
        nn_value = nn_value.numpy()[0][0]

        self.P[s] = nn_policy
        self.N[s] = [0] * othello.ACTION_SPACE
        self.W[s] = [0] * othello.ACTION_SPACE

        valid_actions = othello.get_valid_actions(state, current_player)

        #: cache valid actions and next state to save computation
        self.next_states[s] = [
            othello.step(state, action, current_player)[0] if
            (action in valid_actions) else None
            for action in range(othello.ACTION_SPACE)
        ]

        return nn_value
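
For reference, a minimal sketch (not part of the example above) of how the statistics cached by _expand are typically consumed in the PUCT selection step of AlphaZero-style MCTS; the method name _puct_action and the constant c_puct are assumptions.

    def _puct_action(self, s, valid_actions, c_puct=1.0):
        #: sketch (assumption): combine the mean value Q = W/N with the
        #: prior-weighted exploration bonus U; both read the P, N, W tables
        #: filled in by _expand
        total_visits = sum(self.N[s])
        scores = []
        for a in valid_actions:
            q = self.W[s][a] / self.N[s][a] if self.N[s][a] > 0 else 0.0
            u = c_puct * self.P[s][a] * np.sqrt(total_visits) / (1 + self.N[s][a])
            scores.append((q + u, a))
        #: highest Q + U wins; the prior term dominates for rarely visited actions
        return max(scores)[1]
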
Example #2
    def setup(self):

        state = othello.get_initial_state()

        self.network = AlphaZeroResNet(action_space=othello.ACTION_SPACE)

        #: a dummy forward pass builds the subclassed Keras model's variables
        #: so that load_weights below has weights to restore into
        self.network.predict(othello.encode_state(state, 1))

        self.network.load_weights(self.weights_path)

        self.mcts = MCTS(network=self.network, alpha=None)
@ray.remote  #: main() calls selfplay.remote(...), so this must be a Ray remote task
def selfplay(weights, num_mcts_simulations, dirichlet_alpha):

    record = []

    state = othello.get_initial_state()

    network = AlphaZeroResNet(action_space=othello.ACTION_SPACE)

    network.predict(othello.encode_state(state, 1))

    network.set_weights(weights)

    mcts = MCTS(network=network, alpha=dirichlet_alpha)

    current_player = 1

    done = False

    i = 0

    while not done:

        mcts_policy = mcts.search(root_state=state,
                                  current_player=current_player,
                                  num_simulations=num_mcts_simulations)

        if i <= 10:
            #: AlphaZero sets the temperature τ = 1 for the first 30 moves of a game;
            #: here the first ~10 moves are sampled in proportion to the MCTS visit
            #: counts, and later moves are chosen greedily (see the else branch)
            action = np.random.choice(range(othello.ACTION_SPACE),
                                      p=mcts_policy)
        else:
            action = random.choice(
                np.where(np.array(mcts_policy) == max(mcts_policy))[0])

        record.append(Sample(state, mcts_policy, current_player, None))

        next_state, done = othello.step(state, action, current_player)

        state = next_state

        current_player = -current_player

        i += 1

    #: win: 1, lose: -1, draw: 0
    reward_first, reward_second = othello.get_result(state)

    for sample in reversed(record):
        sample.reward = reward_first if sample.player == 1 else reward_second

    return record
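
The Sample container used in selfplay is not shown above; a minimal sketch, assuming a mutable dataclass so that sample.reward can be filled in once the game result is known (field types are assumptions):

from dataclasses import dataclass
from typing import Any


@dataclass
class Sample:
    state: Any          #: board state at the time of the move
    mcts_policy: Any    #: visit-count distribution returned by mcts.search
    player: int         #: 1 for the first player, -1 for the second
    reward: Any         #: filled in after the game: 1 win, -1 loss, 0 draw
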
    def get_minibatch(self, batch_size):

        #: sample indices with replacement (np.random.choice default)
        indices = np.random.choice(range(len(self.buffer)), size=batch_size)

        samples = [self.buffer[idx] for idx in indices]

        states = np.stack(
            [othello.encode_state(s.state, s.player) for s in samples], axis=0)

        mcts_policy = np.array([s.mcts_policy for s in samples],
                               dtype=np.float32)

        rewards = np.array([s.reward for s in samples],
                           dtype=np.float32).reshape(-1, 1)

        return (states, mcts_policy, rewards)
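
The surrounding ReplayBuffer class is only partially shown; a minimal sketch of the rest, reconstructed from how main() uses it (ReplayBuffer(buffer_size=...), add_record(record), len(replay)). The deque-based storage is an assumption:

import collections


class ReplayBuffer:

    def __init__(self, buffer_size):
        #: old samples are evicted automatically once buffer_size is exceeded
        self.buffer = collections.deque(maxlen=buffer_size)

    def __len__(self):
        return len(self.buffer)

    def add_record(self, record):
        #: record is the list of Samples produced by one self-play game
        self.buffer.extend(record)

    #: get_minibatch(self, batch_size) as shown above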

class ResBlock(kl.Layer):
    #: NOTE: the class name, __init__ signature and conv1 definition are assumptions
    #: mirroring conv2 below; the original snippet begins mid-constructor
    def __init__(self, filters, use_bias=False):
        super().__init__()
        self.conv1 = kl.Conv2D(filters, kernel_size=3, padding="same",
                               use_bias=use_bias, kernel_regularizer=l2(0.001),
                               kernel_initializer="he_normal")
        self.bn1 = kl.BatchNormalization()
        self.conv2 = kl.Conv2D(filters, kernel_size=3, padding="same",
                               use_bias=use_bias, kernel_regularizer=l2(0.001),
                               kernel_initializer="he_normal")
        self.bn2 = kl.BatchNormalization()

    def call(self, x, training=False):

        inputs = x

        x = relu(self.bn1(self.conv1(x), training=training))

        x = self.bn2(self.conv2(x), training=training)
        x = x + inputs  #: skip connection
        x = relu(x)

        return x
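
A minimal sketch of a policy-value network in the spirit of the AlphaZeroResNet used above (the real class may differ); it stacks the residual blocks and returns a softmax policy and a tanh value, matching how p_pred and v_pred are consumed in main(). The head layout, the batch-handling predict helper, and the reuse of the module aliases kl, l2, relu, np and tf are assumptions.

class AlphaZeroResNetSketch(tf.keras.Model):

    def __init__(self, action_space, n_blocks=5, filters=64):
        super().__init__()
        self.conv_in = kl.Conv2D(filters, kernel_size=3, padding="same",
                                 use_bias=False, kernel_regularizer=l2(0.001),
                                 kernel_initializer="he_normal")
        self.bn_in = kl.BatchNormalization()
        self.res_blocks = [ResBlock(filters=filters) for _ in range(n_blocks)]
        self.flatten = kl.Flatten()
        self.policy_head = kl.Dense(action_space, activation="softmax")
        self.value_head = kl.Dense(1, activation="tanh")

    def call(self, x, training=False):
        x = relu(self.bn_in(self.conv_in(x), training=training))
        for block in self.res_blocks:
            x = block(x, training=training)
        x = self.flatten(x)
        return self.policy_head(x), self.value_head(x)

    def predict(self, state):
        #: convenience wrapper (assumption): add a batch dimension for a single state
        if len(state.shape) == 3:
            state = state[np.newaxis, ...]
        return self(state)
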


if __name__ == "__main__":
    import othello

    state = othello.get_initial_state()
    x = othello.encode_state(state, current_player=1)
    x = x[np.newaxis, ...]
    print(x.shape)
    action_space = othello.N_ROWS * othello.N_COLS
    #network = AlphaZeroResNet(action_space=action_space, n_blocks=5, filters=64)
    network = SimpleCNN(action_space=action_space, filters=512)
    print(network(x))
@ray.remote  #: main() calls testplay.remote(...), so this must be a Ray remote task
def testplay(current_weights,
             num_mcts_simulations,
             dirichlet_alpha=None,
             n_testplay=24):

    t = time.time()

    win_count = 0

    network = AlphaZeroResNet(action_space=othello.ACTION_SPACE)

    dummy_state = othello.get_initial_state()

    network.predict(othello.encode_state(dummy_state, 1))

    network.set_weights(current_weights)

    for n in range(n_testplay):

        alphazero = random.choice([1, -1])

        mcts = MCTS(network=network, alpha=dirichlet_alpha)

        state = othello.get_initial_state()

        current_player = 1

        done = False

        while not done:

            if current_player == alphazero:
                mcts_policy = mcts.search(root_state=state,
                                          current_player=current_player,
                                          num_simulations=num_mcts_simulations)
                action = np.argmax(mcts_policy)
            else:
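                #: baseline opponent: othello.greedy_action with epsilon=0.3
                #: (assumed to play a random legal move 30% of the time)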
                action = othello.greedy_action(state,
                                               current_player,
                                               epsilon=0.3)

            next_state, done = othello.step(state, action, current_player)

            state = next_state

            current_player = -1 * current_player

        reward_first, reward_second = othello.get_result(state)

        reward = reward_first if alphazero == 1 else reward_second
        result = "win" if reward == 1 else "lose" if reward == -1 else "draw"

        if reward > 0:
            win_count += 1

        stone_first, stone_second = othello.count_stone(state)

        if alphazero == 1:
            stone_az, stone_tester = stone_first, stone_second
            color = "black"
        else:
            stone_az, stone_tester = stone_second, stone_first
            color = "white"

        message = f"AlphaZero ({color}) {result}: {stone_az} vs {stone_tester}"

        othello.save_img(state, "img", f"test_{n}.png", message)

    elapsed = time.time() - t

    return win_count, win_count / n_testplay, elapsed
def main(num_cpus,
         n_episodes=10000,
         buffer_size=40000,
         batch_size=64,
         epochs_per_update=5,
         num_mcts_simulations=50,
         update_period=300,
         test_period=300,
         n_testplay=20,
         save_period=3000,
         dirichlet_alpha=0.35):

    ray.init(num_cpus=num_cpus, num_gpus=1, local_mode=False)

    logdir = Path(__file__).parent / "log"
    if logdir.exists():
        shutil.rmtree(logdir)
    summary_writer = tf.summary.create_file_writer(str(logdir))

    network = AlphaZeroResNet(action_space=othello.ACTION_SPACE)

    #: initialize network parameters
    dummy_state = othello.encode_state(othello.get_initial_state(), 1)

    network.predict(dummy_state)

    current_weights = ray.put(network.get_weights())

    #optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

    replay = ReplayBuffer(buffer_size=buffer_size)

    #: run self-play episodes in parallel on the CPU workers
    work_in_progresses = [
        selfplay.remote(current_weights, num_mcts_simulations, dirichlet_alpha)
        for _ in range(num_cpus - 2)
    ]

    test_in_progress = testplay.remote(current_weights,
                                       num_mcts_simulations,
                                       n_testplay=n_testplay)

    n_updates = 0
    n = 0
    while n <= n_episodes:

        for _ in tqdm(range(update_period)):
            #: get one worker whose self-play episode has finished
            finished, work_in_progresses = ray.wait(work_in_progresses,
                                                    num_returns=1)
            replay.add_record(ray.get(finished[0]))
            work_in_progresses.extend([
                selfplay.remote(current_weights, num_mcts_simulations,
                                dirichlet_alpha)
            ])
            n += 1

        #: Update network
        if len(replay) >= 20000:
            #if len(replay) >= 2000:

            num_iters = epochs_per_update * (len(replay) // batch_size)
            for i in range(num_iters):

                states, mcts_policy, rewards = replay.get_minibatch(
                    batch_size=batch_size)

                with tf.GradientTape() as tape:

                    p_pred, v_pred = network(states, training=True)
                    value_loss = tf.square(rewards - v_pred)

                    policy_loss = -mcts_policy * tf.math.log(p_pred + 0.0001)
                    policy_loss = tf.reduce_sum(policy_loss,
                                                axis=1,
                                                keepdims=True)

                    loss = tf.reduce_mean(value_loss + policy_loss)

                grads = tape.gradient(loss, network.trainable_variables)
                optimizer.apply_gradients(
                    zip(grads, network.trainable_variables))

                n_updates += 1

                if i % 100 == 0:
                    with summary_writer.as_default():
                        tf.summary.scalar("v_loss",
                                          value_loss.numpy().mean(),
                                          step=n_updates)
                        tf.summary.scalar("p_loss",
                                          policy_loss.numpy().mean(),
                                          step=n_updates)

            current_weights = ray.put(network.get_weights())

        if n % test_period == 0:
            print(f"{n - test_period}: TEST")
            win_count, win_ratio, elapsed_time = ray.get(test_in_progress)
            print(f"SCORE: {win_count}, {win_ratio}, Elapsed: {elapsed_time}")
            test_in_progress = testplay.remote(current_weights,
                                               num_mcts_simulations,
                                               n_testplay=n_testplay)

            with summary_writer.as_default():
                tf.summary.scalar("win_count", win_count, step=n - test_period)
                tf.summary.scalar("win_ratio", win_ratio, step=n - test_period)
                tf.summary.scalar("buffer_size", len(replay), step=n)

        if n % save_period == 0:
            network.save_weights("checkpoints/network")
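
A hedged usage sketch for launching training; the entry-point guard and the CPU count are illustrative assumptions, not part of the snippet above:

if __name__ == "__main__":
    #: hypothetical launch; most cores are reserved for the self-play workers spawned in main()
    main(num_cpus=8)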