Code example #1
class TestAlphaZero(unittest.TestCase):
    def setUp(self):
        self.env = MockEnv()
        self.nnet = MockNNet()
        self.args = dotdict({
            'simulation_num': 100,
            'c_puct': 5,
            'save_weights_path': '',
            'rows': 1,
            'columns': 3,
            'max_sample_pool_size': 100000,
            'sample_pool_file': '',
            'temp_step': 5,
        })
        self.mcts = MCTS(self.nnet, self.env, self.args)
        self.rl = RL(self.nnet, self.env, self.args)

    def test_mcts(self):
        board, player = self.env.get_initial_state()
        self.mcts.simulate(board, player)
        print("visit_count", self.mcts.visit_count)
        print("mean_action_value", self.mcts.mean_action_value)
        print("prior_probability", self.mcts.prior_probability)
        print("terminal_state", self.mcts.terminal_state)
        print("total_visit_count", self.mcts.total_visit_count)
        print("available_actions", self.mcts.available_actions)

        # On a 1x3 board the centre action should collect the most visits.
        counts = self.mcts.visit_count[board]
        self.assertTrue(counts[0] < counts[1])
        self.assertTrue(counts[-1] < counts[-2])

    def test_play_against_itself(self):
        samples = self.rl.play_against_itself()
        print("play_against_itself", samples)
Code example #2
    def play_against_itself(self):
        # Play one self-play game driven by MCTS and return
        # (board, player, policy, value) training samples.
        board, player = self.env.get_initial_state()
        boards, players, policies = [], [], []
        mcts = MCTS(self.nnet, self.env, self.args)
        for i in itertools.count():
            actions, counts = mcts.simulate(board, player)
            pi = counts / sum(counts)
            policy = numpy.zeros(self.args.rows * self.args.columns)
            policy[actions] = pi
            boards.append(board)
            players.append(player)
            policies.append(policy)

            # Mix the visit-count policy with Dirichlet noise to keep exploring.
            proba = 0.75 * pi + 0.25 * numpy.random.dirichlet(
                0.3 * numpy.ones(len(pi)))
            # After temp_step moves play greedily; before that, sample from proba.
            action = actions[numpy.argmax(
                proba)] if i >= self.args.temp_step else numpy.random.choice(
                    actions, p=proba)

            next_board, next_player = self.env.next_state(
                board, action, player)
            winner = self.env.is_terminal_state(next_board, action, player)
            if winner is not None:
                logging.info("winner: %s", winner)
                # Score every recorded position: 0 for a draw, +1 for the
                # player who won, -1 for the player who lost.
                values = [
                    0 if winner == ChessType.EMPTY else
                    (1 if p == winner else -1) for p in players
                ]
                return list(zip(boards, players, policies, values))
            board, player = next_board, next_player
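
The exploration step above mixes the visit-count distribution with Dirichlet(0.3) noise and switches from sampling to argmax after temp_step moves. A tiny standalone sketch of just that computation, with made-up visit counts:

import numpy

counts = numpy.array([10.0, 70.0, 20.0])        # made-up visit counts for 3 actions
pi = counts / counts.sum()                      # -> [0.1, 0.7, 0.2]
noise = numpy.random.dirichlet(0.3 * numpy.ones(len(pi)))
proba = 0.75 * pi + 0.25 * noise                # convex mix, still sums to 1
early_action = numpy.random.choice(len(pi), p=proba)   # before temp_step: sample
late_action = int(numpy.argmax(proba))                  # from temp_step on: argmax
print(early_action, late_action)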
Code example #3
    def __init__(self, game, network, device='cpu', verbose=False):
        self.device = device
        self.game = game
        self.network = network.to(self.device)
        self.competitor = self.network.clone().to(self.device)
        self.competitor.load_state_dict(self.network.state_dict())

        self.verbose = verbose

        self.mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
Code example #4
    def setUp(self):
        self.env = MockEnv()
        self.nnet = MockNNet()
        self.args = dotdict({
            'simulation_num': 100,
            'c_puct': 5,
            'save_weights_path': '',
            'rows': 1,
            'columns': 3,
            'max_sample_pool_size': 100000,
        })
        self.mcts = MCTS(self.nnet, self.env, self.args)
        self.rl = RL(self.nnet, self.env, self.args)
Code example #5
    def launch_arena(self, player2='random', N=20, verbose=False):
        if player2 == 'random':
            player2 = lambda board : np.random.choice(np.where(self.game.valid_moves(board) != 0)[0])
        elif player2 == 'self':
            # Give the opponent its own MCTS instance so the two players do not
            # share search statistics.
            competitor_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
            player2 = lambda board : competitor_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax()
            
        if isinstance(player2, str):
            raise TypeError("other values for player2 not supported")

        current_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)

        arena = Arena(
            lambda board : current_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
            player2,
            self.game
        )

        wins, draws, losses = arena.play(N, verbose=verbose)
        print(f"Current network achieves {wins} wins, {draws} draws, and {losses} losses over player2")
        return wins, draws, losses
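
Arena itself is not shown in these snippets; only its constructor arguments and the play(N, verbose) -> (wins, draws, losses) return value can be read off the calls above. The body below is an assumption about how such an evaluator could look against the game interface (reset / move / flip_board / ended / reward / display) used in code example #10, including how results are attributed to player1.

class Arena:
    """Sketch of the assumed evaluator: two callables mapping a board to an
    action index, plus the game object."""

    def __init__(self, player1, player2, game):
        self.players = (player1, player2)
        self.game = game

    def play(self, n, verbose=False):
        wins = draws = losses = 0
        for i in range(n):
            turn = i % 2                          # alternate the first mover (assumption)
            board = self.game.reset()
            while not self.game.ended(board):
                board = self.game.move(board, self.players[turn](board))
                if self.game.ended(board):
                    break
                board = self.game.flip_board(board)   # same convention as episode()
                turn = 1 - turn
            # Here `board` is in the last mover's perspective and `turn` is that mover.
            reward = self.game.reward(board, player=turn)
            if verbose:
                self.game.display(board)
            if reward == 0:
                draws += 1
            elif (reward > 0) == (turn == 0):
                wins += 1                         # player1 (index 0) took the game
            else:
                losses += 1
        return wins, draws, losses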
Code example #6
    def compare(self, path):
        """
        compare matches the current network against all the saved weights
        in the path directory.
        """
        competitor = self.network.clone().to(self.device)

        for file in natsorted(os.listdir(path))[-2:]:
            file = os.path.join(path, file)
            if file.endswith(".pt"):
                try:
                    competitor_state_dict = torch.load(file)
                    competitor.load_state_dict(competitor_state_dict)
                except Exception:
                    print(f"Unable to load state dict for {file}")
                    continue

                current_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
                competitor_mcts = MCTS(self.game, competitor, device=self.device, verbose=self.verbose)

                arena = Arena(
                    lambda board : current_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                    lambda board : competitor_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                    self.game
                )

                wins, draws, losses = arena.play(10)
                print(f"Current network achieves {wins} wins, {draws} draws, and {losses} losses over {file}")
Code example #7
class GomokuBattleAgent(BattleAgent):
    def __init__(self, nnet, env, args):
        self.nnet = nnet
        self.env = env
        self.args = args
        self.mcts = MCTS(self.nnet, self.env, self.args)

    def next(self, sgf, player):
        actions, counts = self.mcts.simulate(sgf, player)
        pi = counts / sum(counts)
        index = numpy.argmax(0.75 * pi + 0.25 *
                             numpy.random.dirichlet(0.3 * numpy.ones(len(pi))))
        action = actions[index]
        # Assumes row-major action indexing: action = rowIndex * columns + columnIndex.
        return {
            'rowIndex': action // self.args.columns,
            'columnIndex': action % self.args.columns
        }
Code example #8
    def __init__(self, nnet, env, args):
        self.nnet = nnet
        self.env = env
        self.args = args
        self.mcts = MCTS(self.nnet, self.env, self.args)
Code example #9
    def learn(self):
        """
        train the network on self-play games.
        """        
        train_examples = []

        for epoch in range(NUM_ITER):
            print(f"epoch {epoch}/{NUM_ITER}")
            
            self.network.train()

            epoch_examples = []

            # pool = multiprocessing.Pool(multiprocessing.cpu_count())
            # epoch_examples = pool.map(self.launch_episode, list(range(NUM_EPISODES)))

            for episode in range(NUM_EPISODES):
                self.mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
                epoch_examples += self.episode(steps=NUM_MCTS_STEPS)

            train_examples.append(epoch_examples)

            if len(train_examples) > MAX_TRAIN_SIZE:
                train_examples.pop(0)
            
            flattened = []
            for e in train_examples:
                flattened.extend(e)
            
            data = torch.stack([episode[0] for episode in flattened])
            policy = torch.stack([episode[1] for episode in flattened])
            values = torch.tensor([episode[2] for episode in flattened])
            
            # for i in range(data.shape[0]): # debugging
            #     print(i)
            #     board = data[i]
            #     valid_moves = self.game.valid_moves(board)
            #     state_values = []
            #     found = False
            #     for j, valid in enumerate(list(valid_moves)):
            #         if not valid:
            #             continue
            #         n_board = self.game.move(board, j)

            #         if self.game.ended(n_board):
            #             if not found:
            #                 self.game.display(board)

            #             found = True
            #             print("ALL GOOD")
            #             state_values.append(self.game.reward(n_board))
            #             print(values[i], self.game.reward(n_board))
                    
                # if len(state_values) != 0:
                #     print(values[i], state_values)
                #     assert any([abs(values[i]-value) < 1e-3 for value in state_values])

            # print(list(self.network.parameters())[0].data[0])

            self.network.fit(data.to(self.device), [policy.to(self.device), values.to(self.device)], batch_size=BATCH_SIZE, epochs=TRAIN_EPOCHS, shuffle=False)

            self.network.eval()
            
            current_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
            competitor_mcts = MCTS(self.game, self.competitor, device=self.device, verbose=self.verbose)

            arena = Arena(
                lambda board : current_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                lambda board : competitor_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                self.game
            )

            wins, draws, losses = arena.play(NUM_SELF_PLAY, verbose=False)
            print(f"New network achieves {wins} wins, {draws} draws, and {losses} losses over the previous iteration.")
            if wins + losses == 0 or wins / (wins + losses) < UPDATE_THRESHOLD: # failed to win enough
                print(f"[INFO] failed to achieve a {UPDATE_THRESHOLD} win rate. reverting to previous version.")
                self.network.load_state_dict(self.competitor.state_dict()) # reject, network reverts
            else:
                print(f"[INFO] accepted new version.")
                self.competitor.load_state_dict(self.network.state_dict()) # accept, competitor is current
            
            if epoch % SAVE_FREQ == 0:
                torch.save(self.network.state_dict(), f"backups/network-{epoch}.pt")

            if epoch % TEST_ALL_FREQ == 0:
                self.compare("backups")
Code example #10
class Train:
    def __init__(self, game, network, device='cpu', verbose=False):
        self.device = device
        self.game = game
        self.network = network.to(self.device)
        self.competitor = self.network.clone().to(self.device)
        self.competitor.load_state_dict(self.network.state_dict())

        self.verbose = verbose

        self.mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)

    def episode(self, steps=25, symmetries=True):
        examples = []
        board = self.game.reset()
        curr_player = 0
        episode_step = 0
        
        if self.verbose: self.game.display(board)
        
        while True:
            # Temperature 1 (sample) for the first TEMP_THRESHOLD moves, then greedy.
            temp = int(episode_step < TEMP_THRESHOLD)

            pi = self.mcts.get_action_prob(board, temp=temp, steps=steps)

            if self.verbose: print(pi)
            
            # Store training samples, optionally augmented with board symmetries.
            if symmetries:
                for b, p in self.game.get_symmetries(self.game.tensor(board), pi):
                    examples.append((b, p, curr_player))
            else:
                examples.append((self.game.tensor(board), pi, curr_player))

            action = np.random.choice(len(pi), p=np.array(pi))

            board = self.game.move(board, action)
            if self.verbose: 
                if curr_player == 0:
                    self.game.display(board)
                else:
                    self.game.display(self.game.flip_board(board))
            
            # Game over: score each stored position from its own player's perspective.
            if self.game.ended(board):
                reward = self.game.reward(board, player=curr_player)
                if self.verbose: print(f"Reward: {reward}, Player: {curr_player}")
                return [(board, pi, reward * (-1) ** (curr_player != past_player)) for (board, pi, past_player) in examples]

            # Hand the flipped board to the other player and continue.
            board = self.game.flip_board(board)
            curr_player = 1 - curr_player
            episode_step += 1

    def play(self, player2):
        """
        play a game against a human player using keyboard input.
        """
        self.network.eval()

        board = self.game.reset()
        self.game.display(board)

        while True:
            action = int(input("move: "))
            board = self.game.move(board, action)
            self.game.display(board)

            reward = self.game.reward(board, player=0)

            if reward == 1:
                print("You win!")
                return
            elif reward == -1:
                print("The computer wins!")
                return
            elif self.game.ended(board):
                print("Tie!")
                return

            board = self.game.flip_board(board)
            pi = self.mcts.get_action_prob(board, steps=NUM_MCTS_STEPS)
            _, value = self.network.predict(self.game.tensor(board).to(self.device))
            print(f"action_probs: {pi}, value: {value}")

            action = pi.argmax()
            print("move: ", action.item())

            board = self.game.move(board, action)

            self.game.display(self.game.flip_board(board))

            reward = self.game.reward(board, player=1)

            if reward == 1:
                print("The computer wins!")
                return
            elif reward == -1:
                print("You win!")
                return
            elif self.game.ended(board):
                print("Tie!")
                return

            board = self.game.flip_board(board)

    def launch_arena(self, player2='random', N=20, verbose=False):
        if player2 == 'random':
            player2 = lambda board : np.random.choice(np.where(self.game.valid_moves(board) != 0)[0])
        elif player2 == 'self':
            # Give the opponent its own MCTS instance so the two players do not
            # share search statistics.
            competitor_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
            player2 = lambda board : competitor_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax()
            
        if isinstance(player2, str):
            raise TypeError("other values for player2 not supported")

        current_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)

        arena = Arena(
            lambda board : current_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
            player2,
            self.game
        )

        wins, draws, losses = arena.play(N, verbose=verbose)
        print(f"Current network achieves {wins} wins, {draws} draws, and {losses} losses over player2")
        return wins, draws, losses

    def learn(self):
        """
        train the network on self-play games.
        """        
        train_examples = []

        for epoch in range(NUM_ITER):
            print(f"epoch {epoch}/{NUM_ITER}")
            
            self.network.train()

            epoch_examples = []

            # pool = multiprocessing.Pool(multiprocessing.cpu_count())
            # epoch_examples = pool.map(self.launch_episode, list(range(NUM_EPISODES)))

            for episode in range(NUM_EPISODES):
                self.mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
                epoch_examples += self.episode(steps=NUM_MCTS_STEPS)

            train_examples.append(epoch_examples)

            if len(train_examples) > MAX_TRAIN_SIZE:
                train_examples.pop(0)
            
            flattened = []
            for e in train_examples:
                flattened.extend(e)
            
            data = torch.stack([episode[0] for episode in flattened])
            policy = torch.stack([episode[1] for episode in flattened])
            values = torch.tensor([episode[2] for episode in flattened])
            
            # for i in range(data.shape[0]): # debugging
            #     print(i)
            #     board = data[i]
            #     valid_moves = self.game.valid_moves(board)
            #     state_values = []
            #     found = False
            #     for j, valid in enumerate(list(valid_moves)):
            #         if not valid:
            #             continue
            #         n_board = self.game.move(board, j)

            #         if self.game.ended(n_board):
            #             if not found:
            #                 self.game.display(board)

            #             found = True
            #             print("ALL GOOD")
            #             state_values.append(self.game.reward(n_board))
            #             print(values[i], self.game.reward(n_board))
                    
                # if len(state_values) != 0:
                #     print(values[i], state_values)
                #     assert any([abs(values[i]-value) < 1e-3 for value in state_values])

            # print(list(self.network.parameters())[0].data[0])

            self.network.fit(data.to(self.device), [policy.to(self.device), values.to(self.device)], batch_size=BATCH_SIZE, epochs=TRAIN_EPOCHS, shuffle=False)

            self.network.eval()
            
            current_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
            competitor_mcts = MCTS(self.game, self.competitor, device=self.device, verbose=self.verbose)

            arena = Arena(
                lambda board : current_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                lambda board : competitor_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                self.game
            )

            wins, draws, losses = arena.play(NUM_SELF_PLAY, verbose=False)
            print(f"New network achieves {wins} wins, {draws} draws, and {losses} losses over the previous iteration.")
            if wins + losses == 0 or wins / (wins + losses) < UPDATE_THRESHOLD: # failed to win enough
                print(f"[INFO] failed to achieve a {UPDATE_THRESHOLD} win rate. reverting to previous version.")
                self.network.load_state_dict(self.competitor.state_dict()) # reject, network reverts
            else:
                print(f"[INFO] accepted new version.")
                self.competitor.load_state_dict(self.network.state_dict()) # accept, competitor is current
            
            if epoch % SAVE_FREQ == 0:
                torch.save(self.network.state_dict(), f"backups/network-{epoch}.pt")

            if epoch % TEST_ALL_FREQ == 0:
                self.compare("backups")
        
    def compare(self, path):
        """
        compare matches the current network against all the saved weights
        in the path directory.
        """
        competitor = self.network.clone().to(self.device)

        for file in natsorted(os.listdir(path))[-2:]:
            file = os.path.join(path, file)
            if file.endswith(".pt"):
                try:
                    competitor_state_dict = torch.load(file)
                    competitor.load_state_dict(competitor_state_dict)
                except Exception:
                    print(f"Unable to load state dict for {file}")
                    continue

                current_mcts = MCTS(self.game, self.network, device=self.device, verbose=self.verbose)
                competitor_mcts = MCTS(self.game, competitor, device=self.device, verbose=self.verbose)

                arena = Arena(
                    lambda board : current_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                    lambda board : competitor_mcts.get_action_prob(board, steps=NUM_MCTS_STEPS, temp=0).argmax(),
                    self.game
                )

                wins, draws, losses = arena.play(10)
                print(f"Current network achieves {wins} wins, {draws} draws, and {losses} losses over {file}")