예제 #1
0
파일: AI_train.py 프로젝트: amoliu/Renju-AI
def action_model(model_type, model, board, player):
    """
    Build a Renju position and run one prediction with the requested model.

    :param model_type: "policy_dl", "policy_rl", "policy_rollout" or "value_net"
    :param model: object exposing ``predict(batch)``
    :param board: a numpy array with size (15 x 15)
    :param player: "black" selects the black player, any other value selects white
    :return: first element of ``model.predict`` for the encoded position, or
        None (after logging an error) when ``model_type`` is not supported
    """
    side = RenjuGame.PLAYER_BLACK if player == "black" else RenjuGame.PLAYER_WHITE
    position = RenjuGame(board=board, player=side)
    # Each model family consumes a different encoding of the same position.
    if model_type in ("policy_dl", "policy_rl"):
        features = position.get_states()
    elif model_type == "policy_rollout":
        features = position.get_states(flatten=True)
    elif model_type == "value_net":
        features = position.get_states(player_plane=True)
    else:
        logger.error("not support model type=%s" % model_type)
        return None
    return model.predict([features])[0]
예제 #2
0
def action_model(model_type, model, board, player):
    """
    Encode a board for the given model type and return the model's output.

    :param model_type: "policy_dl", "policy_rl", "policy_rollout" or "value_net"
    :param model: policy model or value model; must provide ``predict``
    :param board: a numpy array with size (15 x 15)
    :param player: "black" for the black player; everything else means white
    :return: ``model.predict([state])[0]``, or None when the model type is
        unknown (an error is logged in that case)
    """
    if player == "black":
        colour = RenjuGame.PLAYER_BLACK
    else:
        colour = RenjuGame.PLAYER_WHITE
    position = RenjuGame(board=board, player=colour)

    # Keyword arguments for get_states(); None marks an unsupported type.
    options = None
    if model_type == "policy_dl" or model_type == "policy_rl":
        options = {}
    elif model_type == "policy_rollout":
        options = {"flatten": True}
    elif model_type == "value_net":
        options = {"player_plane": True}

    if options is None:
        logger.error("not support model type=%s" % model_type)
        return None

    encoded = position.get_states(**options)
    return model.predict([encoded])[0]
예제 #3
0
def simulate(model_type, model, board, player, random_prob=0.95):
    """
    Play a game to the end with ``model`` choosing every move.

    :param model_type: "policy_dl", "policy_rl", "policy_rollout" or "value_net"
    :param model: object exposing ``predict(batch)``
    :param board: starting board handed to ``RenjuGame``
    :param player: "black" selects the black player, any other value selects white
    :param random_prob: probability of using ``choose_action``; otherwise
        ``weighted_choose_action`` is used for the move
    :return: the reward reported by the terminal step, or 0 when no action is
        available
    :raises ValueError: if ``model_type`` is not supported
    """
    # Resolve the state encoding once, before touching the game.  Previously an
    # unknown model_type left ``predict_vals`` unbound and crashed inside the loop.
    if model_type == "policy_dl" or model_type == "policy_rl":
        state_options = {}
    elif model_type == "policy_rollout":
        state_options = {"flatten": True}
    elif model_type == "value_net":
        state_options = {"player_plane": True}
    else:
        raise ValueError("not support model type=%s" % model_type)

    if player == "black":
        player = RenjuGame.PLAYER_BLACK
    else:
        player = RenjuGame.PLAYER_WHITE
    game = RenjuGame(board=board, player=player)
    while True:  # loop game
        state = game.get_states(**state_options)
        predict_vals = model.predict([state])[0]
        if random.random() < random_prob:
            action = game.choose_action(predict_vals)
        else:  # occasionally use the weighted chooser instead
            action = game.weighted_choose_action(predict_vals)
        if action is None:
            return 0
        _, reward_n, terminal_n = game.step_games(action)
        if terminal_n:
            return reward_n
예제 #4
0
def single_import_renju_pattern(index):
    """
    Copy one row of ``renju.db`` into ``patterns.db`` as a pattern record.

    The row at offset ``index`` of the ``renju`` table is loaded, its board is
    rebuilt into a ``RenjuGame`` and the values of ``get_patterns()`` are stored
    as a comma-separated string together with the row's player and action.

    :param index: zero-based row offset in the ``renju`` table
    :return: None
    """
    renju_db = DBWrapper(db_path="./data/renju.db")
    pattern_db = DBWrapper(db_path="./data/patterns.db")
    row = renju_db.query("select * from renju limit 1 offset ?", index)[0]
    position = RenjuGame(board=stream_to_board(row["board"]), player=row["player"])
    pattern = ','.join(map(str, position.get_patterns()))
    action = row["action"]
    # Retry until the insert succeeds.  Only Exception is caught so that
    # KeyboardInterrupt / SystemExit can still stop an otherwise endless loop.
    while True:
        try:
            pattern_db.execute("insert INTO pattern(pattern, player, action) VALUES (?, ?, ?)",
                               pattern, row["player"], action)
            break
        except Exception:
            logger.warn("fail to insert into pattern_db, try again")
예제 #5
0
파일: AI_train.py 프로젝트: amoliu/Renju-AI
def simulate(model_type, model, board, player, random_prob=0.95):
    """
    Roll a game out to its end, letting ``model`` pick each move.

    :param model_type: "policy_dl", "policy_rl", "policy_rollout" or "value_net"
    :param model: object exposing ``predict(batch)``
    :param board: starting board handed to ``RenjuGame``
    :param player: "black" selects the black player, any other value selects white
    :param random_prob: probability of picking the move with ``choose_action``;
        the remaining moves are picked with ``weighted_choose_action``
    :return: reward of the terminal step, or 0 when no action can be chosen
    :raises ValueError: if ``model_type`` is not supported
    """
    # Fail fast on an unknown model type: the original fell through every branch
    # and then hit an unbound ``predict_vals`` on the first loop iteration.
    if model_type in ("policy_dl", "policy_rl"):
        encode_kwargs = {}
    elif model_type == "policy_rollout":
        encode_kwargs = {"flatten": True}
    elif model_type == "value_net":
        encode_kwargs = {"player_plane": True}
    else:
        raise ValueError("not support model type=%s" % model_type)

    if player == "black":
        player = RenjuGame.PLAYER_BLACK
    else:
        player = RenjuGame.PLAYER_WHITE
    game = RenjuGame(board=board, player=player)
    while True:  # loop game
        predict_vals = model.predict([game.get_states(**encode_kwargs)])[0]
        if random.random() < random_prob:
            action = game.choose_action(predict_vals)
        else:  # fall back to the weighted chooser for this move
            action = game.weighted_choose_action(predict_vals)
        if action is None:
            return 0
        _, reward_n, terminal_n = game.step_games(action)
        if terminal_n:
            return reward_n
예제 #6
0
파일: AI_mcts.py 프로젝트: xsir317/Renju-AI
 def acquire_thread(self, player):
     for _, _thread in self.threads.items():
         if _thread.signal is SIGNAL_FREE:
             print "ai player:", player
             _thread.root = Node(RenjuGame(board=None, player=player))
             _thread.root.position.board[7][7] = RenjuGame.STONE_BLACK
             if player == "black":
                 _thread.root.position.player = RenjuGame.PLAYER_WHITE
             self.simulate(_thread.name)
             return _thread.name
     return None
예제 #7
0
def predict_model():
    """
    Restore the trained network and evaluate it on an empty default game.

    The inference graph is built on the CPU, variables are initialised and then
    restored from ``train_dir``, and the logits for the state of
    ``RenjuGame(board=None, player=None)`` are computed.

    :return: the result of ``sess.run([logits], ...)``, i.e. a one-element list
        holding the logits for the single input state
    """
    with tf.device("/cpu:0"):
        with tf.name_scope('%s_%d' % (TOWER_NAME, 0)):
            states = tf.placeholder(tf.float32, [None, board_size, board_size, planes])
            logits = inference(states)
    saver = tf.train.Saver()
    init = tf.initialize_all_variables()
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False)
    )
    try:
        sess.run(init)
        restore_model(sess, train_dir, saver)
        board, player = None, None
        position = RenjuGame(board=board, player=player)
        state = np.array([position.get_states()], dtype=np.float32)
        # The debugger breakpoint and the call on the undefined name ``session``
        # were removed; the prediction is returned instead of being dropped.
        return sess.run([logits], feed_dict={states: state})
    finally:
        # Release the session's resources even if restoring or running fails.
        sess.close()
예제 #8
0
 def import_RenjuNet(self, file_path):
     """
     Import the games of a RenjuNet XML file into the ``renju`` table.

     Every ``<game>`` element whose id is not yet stored is replayed move by
     move; before each move the current board, the player to move and the
     transformed action are inserted as one row.  Moves are tokens made of a
     letter (row, 'a' = 0) followed by a 1-based column number.

     :param file_path: path of the XML file to import
     :return: None
     """
     if not os.path.exists(file_path):
         logger.error("not found file: %s" % file_path, to_exit=True)
     # read xml file; the context manager closes the handle once it is parsed
     with open(file_path, 'r') as xml_file:
         bs_tree = BeautifulSoup(xml_file.read())
     games = bs_tree.find_all("game")
     # insert moves
     game_num = len(games)
     new_games = 0   # games actually inserted by this call
     move_count = 0  # rows actually inserted by this call
     for step, game in enumerate(games, 1):
         gid = int(game.attrs["id"])
         moves = game.move.text.strip().replace("%20", " ").split(" ")
         if len(self.db.query("select id from renju WHERE gid=?", gid)) == 0:  # gid not stored yet
             renju_game = RenjuGame()
             # ``mid`` keeps counting blank tokens so stored move ids match
             # the numbering used by earlier imports.
             for mid, move in enumerate(moves):
                 move = move.strip()
                 if move == "":
                     continue
                 board_stream = board_to_stream(renju_game.board)
                 player = renju_game.player
                 row = ord(move[0]) - ord('a')
                 col = int(move[1:]) - 1
                 action = renju_game.transform_action((row, col))
                 # insert
                 self.db.execute("insert INTO renju (gid, mid, board, player, action) VALUES (?, ?, ?, ?, ?)",
                                 gid, mid, board_stream, player, action)
                 # do move
                 renju_game.do_move((row, col))
                 move_count += 1
             new_games += 1
         # report progress for skipped games as well
         if step % 100 == 0:
             print("load games= %d / %d" % (step, game_num))
     logger.info("newly insert games=%d, moves=%d" % (new_games, move_count))
     print("finish import moves")
예제 #9
0
 def random_fetch_rows(self, fetch_size):
     """
     Load a random selection of stored positions.

     :param fetch_size: number of ids to draw (without replacement) from ``self.ids``
     :return: a list of tuples (instance of RenjuGame, action of int)
     """
     chosen_ids = random.sample(self.ids, fetch_size)
     id_list = ",".join(str(one_id) for one_id in chosen_ids)
     records = self.db.query("SELECT board,player,action FROM renju where id IN (%s)" % id_list)
     return [
         (RenjuGame(board=stream_to_board(record["board"]), player=record["player"]),
          record["action"])
         for record in records
     ]
예제 #10
0
 def iterator_fetch_rows(self, batch_size):
     """
     Iterate over all stored positions in batches of ``batch_size``.

     A batch that comes back short (typically the last one) is padded with
     randomly repeated samples from the same batch so that every yielded
     list has exactly ``batch_size`` entries.  A batch for which the query
     returns no rows is skipped, because there is nothing to pad from.

     :param batch_size: number of samples per yielded batch
     :return: generator of lists of tuples (instance of RenjuGame, action)
     """
     for offset in range(0, len(self.ids), batch_size):
         # slicing clamps at the end of the list, no explicit bound needed
         batch_ids = ','.join(map(str, self.ids[offset: offset + batch_size]))
         rows = self.db.query("SELECT board,player,action FROM renju WHERE id in (%s)" % batch_ids)
         samples = []
         for row in rows:
             board = stream_to_board(row["board"])
             player = row["player"]
             action = row["action"]
             samples.append((RenjuGame(board=board, player=player), action))
         if not samples:
             # random.choice([]) would raise IndexError below
             continue
         while len(samples) < batch_size:
             samples.append(random.choice(samples))
         yield samples
예제 #11
0
def play_games(args):
    """
    Run one MCTS search from the position described by ``args`` and print the move.

    :param args: parsed arguments providing ``player``, ``board`` (a board stream,
        "" for the default board) and the ``mcts_visit_threshold``,
        ``mcts_virtual_loss``, ``mcts_explore_rate`` and ``mcts_mix_lambda`` options
    :return: None; the starting board and the chosen action are printed
    """
    player = args.player
    board_stream = args.board
    if board_stream != "":
        if not is_legal_stream(board_stream):
            logger.error("not legal board stream:[%s]" % board_stream,
                         to_exit=True)
        board = stream_to_board(board_stream)
    else:
        board = None
    root = RenjuGame(board=board, player=player)
    rpc = ModelRPC(args)
    mcst = MCTS(rpc,
                visit_threshold=args.mcts_visit_threshold,
                virtual_loss=args.mcts_virtual_loss,
                explore_rate=args.mcts_explore_rate,
                mix_lambda=args.mcts_mix_lambda)
    root = mcst.simulation(root)
    _, action = mcst.decision(root)
    print(board)
    # The format string was previously printed verbatim ("action: %d <value>")
    # because the value was passed as a second print item instead of via ``%``.
    print("action: %s" % (action,))
예제 #12
0
 def next_fetch_rows(self, batch_size):
     """
     Return the next batch of stored positions, advancing ``self.fetch_index``.

     Once the index has reached the end of ``self.ids`` the index is reset,
     ``shuffle_datas()`` is called and reading restarts from the beginning.

     :param batch_size: maximum number of ids to read in this call
     :return: a list of tuples (instance of RenjuGame, action)
     """
     first = self.fetch_index
     last = min(self.fetch_index + batch_size, len(self.ids))
     if first >= last:
         # ids exhausted: start a new pass over freshly shuffled data
         self.fetch_index = 0
         self.shuffle_datas()
         first, last = 0, batch_size
     id_list = ','.join(str(one_id) for one_id in self.ids[first: last])
     records = self.db.query("SELECT board,player,action FROM renju WHERE id in (%s)" % id_list)
     samples = [
         (RenjuGame(board=stream_to_board(record["board"]), player=record["player"]),
          record["action"])
         for record in records
     ]
     self.fetch_index = last
     return samples
예제 #13
0
def sampling_for_value_network(rpc, sample_num, sample_file, max_time_steps=225):
    """
    Collect (position, reward) samples for training the value network.

    One game is first played with the policy-DL model and every position
    reached is recorded.  For each sample a recorded position is picked at
    random, one random move is played on it, and the policy-RL model then plays
    the game out.  The position after the first policy-RL move is stored
    together with the reward of the finished game (0 for a draw).

    :param rpc: client exposing ``policy_dl_rpc`` and ``policy_rl_rpc``
    :param sample_num: total number of samples wanted, including those already
        present in ``sample_file``
    :param sample_file: pickle file used to resume from and to checkpoint into
    :param max_time_steps: upper bound on the recorded time step to branch from
    :return: list of (RenjuGame, reward) tuples
    """
    sample_games = []
    if os.path.exists(sample_file):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(sample_file, 'rb') as sample_fp:
            sample_games = cPickle.load(sample_fp)
        logger.info("load sample file: %s, samples=%d" % (sample_file, len(sample_games)))
    sample_sets = set()  # used to check unique sample
    game = RenjuGame()
    record_policy_dl_boards = []
    # move step by policy dl
    game.reset_game()
    record_policy_dl_boards.append(game.replicate_game())
    while True:
        action = game.choose_action(
            rpc.policy_dl_rpc(board_to_stream(game.board), game.get_player_name()))
        if action is None:
            break
        _, _, terminal = game.step_games(action)
        if terminal:
            break
        record_policy_dl_boards.append(game.replicate_game())
    max_time_steps = min(max_time_steps, len(record_policy_dl_boards)) - 1
    unsaved = False  # True while sample_games holds samples not yet written
    # sample game
    while len(sample_games) < sample_num:
        sampled_game = None
        while True:  # loop to find legal sample
            flag_time_step = random.randint(1, max_time_steps)
            recorded_game = record_policy_dl_boards[flag_time_step - 1].replicate_game()
            random_action = recorded_game.random_action()
            if random_action is None:
                break
            random_state, _, terminal = recorded_game.step_games(random_action)
            if not terminal and not str(random_state) in sample_sets:
                sample_sets.add(str(random_state))
                break
        if random_action is None:  # invalid loop
            continue
        # move step by policy rl
        time_step = flag_time_step
        while True:  # simulate game by policy rl
            actions = rpc.policy_rl_rpc(board_to_stream(recorded_game.board), recorded_game.get_player_name())
            action = recorded_game.choose_action(actions)
            if action is None:  # game drawn
                sampled_reward = 0
                break
            _, reward, terminal = recorded_game.step_games(action)
            time_step += 1
            if time_step == (flag_time_step + 1):  # record board
                sampled_game = recorded_game.replicate_game()
            if terminal:  # record value
                sampled_reward = reward
                break
        if sampled_game is not None:
            sample_games.append((sampled_game, sampled_reward))
            unsaved = True
            logger.info("sample simulate, sample_step=%d, time_step=%d" % (len(sample_games), time_step))
            # Checkpoint only when a sample was actually added, so the file is
            # not rewritten again and again on iterations that produce nothing.
            if len(sample_games) % 100 == 0:
                with open(sample_file, "wb") as sample_fp:
                    cPickle.dump(sample_games, sample_fp, protocol=2)
                unsaved = False
                logger.info("create value network sample, step=%d" % len(sample_games))
    if unsaved:
        # Persist the tail that did not land on a multiple of 100.
        with open(sample_file, "wb") as sample_fp:
            cPickle.dump(sample_games, sample_fp, protocol=2)
    return sample_games
예제 #14
0
def train_rl_network(batch_games=128,
                     save_step=10000,
                     max_model_pools=5,
                     init_epsilon=0.5,
                     final_epsilon=0.01,
                     explore=1000000,
                     action_repeat=32,
                     mini_batch_size=64):
    """
    Train the policy network on games generated by self-play.

    Each round plays games until ``batch_games`` games are buffered, runs
    ``action_repeat // gpu_num`` optimisation steps on mini-batches sampled
    from the buffered games, and then drops the oldest game.  One side of every
    game is played through ``rpc`` (a model picked from the pool of saved
    checkpoints), the other side by the network being trained.  Runs forever.

    :param batch_games: number of games kept in the sampling buffer
    :param save_step: a checkpoint is written when ``train_step`` is a multiple of this
    :param max_model_pools: maximum number of checkpoints kept as opponents
    :param init_epsilon: epsilon used when the parameter file supplies none
    :param final_epsilon: epsilon stops decreasing once it is not above this value
    :param explore: number of updates over which epsilon decays linearly
    :param action_repeat: mini-batches consumed per round, summed over all GPUs
    :param mini_batch_size: samples per mini-batch and GPU tower
    :return: never returns
    """
    args = parser_argument().parse_args()
    rpc = ModelRPC(args)
    game = RenjuGame()
    batch_states, batch_actions, batch_rewards = deque(), deque(), deque()
    model_pools = []
    param_file = "%s/param.json" % train_dir
    params = param_unserierlize(param_file,
                                init_params={
                                    "global_step": 0,
                                    "epsilon": init_epsilon
                                })
    global_step_val, epsilon = params["global_step"], params["epsilon"]
    # load model
    sess, saver, summary_writer, train_op, loss, accuracy, global_step, lr, tower_feeds, tower_logits = network(
    )
    # optimisation steps per round; every step feeds one mini-batch per GPU
    updates_per_round = action_repeat // gpu_num
    train_step = 0
    while True:
        start_time = time.time()
        # choose policy network for opponent player from model pools
        if train_step % 10 == 0:
            if len(model_pools) > 0:
                model_file = random.choice(model_pools)
            else:
                model_file = None
            rpc.switch_model("policy_rl", model_file=model_file)
        while len(batch_states) < batch_games:
            # colour played by the opponent (rpc) model in this game
            black_opponent = random.choice([True, False])
            # reset game
            game.reset_game()
            # simulate game by current parameter
            states, actions, rewards = [], [], []
            # presumably returns only the initial state when no action is given
            # -- TODO confirm against RenjuGame.step_games
            state = game.step_games(None)
            while True:  # loop current game
                # self-play, current model V.S. history model
                if (black_opponent and game.player == RenjuGame.PLAYER_BLACK) \
                        or (not black_opponent and game.player == RenjuGame.PLAYER_WHITE):
                    predict_probs = rpc.policy_rl_rpc(
                        board_to_stream(game.board), game.get_player_name())
                else:  # current player
                    predict_probs = sess.run(
                        [tower_logits[0]],
                        feed_dict={tower_feeds[0][0]: [state]})[0][0]
                if random.random() < epsilon:  # explore: weighted choice
                    action = game.weighted_choose_action(predict_probs)
                else:
                    action = game.choose_action(predict_probs)
                if action is None:
                    final_reward = 0
                    break
                # step game
                state_n, reward_n, terminal_n = game.step_games(action)
                # store (state, action)
                states.append(state)
                actions.append(action)
                # set new states
                state = state_n
                if terminal_n:
                    final_reward = reward_n
                    break
                # check whether game drawn
                if game.random_action(
                ) is None:  # game drawn, equal end, reward=0
                    final_reward = 0
                    logger.info("game drawn, so amazing...")
                    break
            # store (reward): the sign alternates with the ply
            for step in xrange(len(states)):
                if step % 2 == 0:
                    rewards.append(final_reward)
                else:
                    rewards.append(-final_reward)
            # store states of ith game
            batch_states.append(states)
            batch_actions.append(actions)
            batch_rewards.append(rewards)
        # fit model by mini batch
        avg_loss, avg_acc = 0.0, 0.0
        for _ in xrange(updates_per_round):
            train_step += 1
            feeds = {}
            for gpu_id in xrange(gpu_num):
                # Fresh lists per tower: reusing one shared buffer made every
                # tower's feed point at the same (last written) mini-batch.
                tower_states, tower_actions, tower_rewards = [], [], []
                for _sample in xrange(mini_batch_size):
                    game_idx = random.randint(0, len(batch_states) - 1)
                    game_time_step_idx = random.randint(
                        0,
                        len(batch_states[game_idx]) - 1)
                    tower_states.append(
                        batch_states[game_idx][game_time_step_idx])
                    tower_actions.append(
                        batch_actions[game_idx][game_time_step_idx])
                    tower_rewards.append(
                        batch_rewards[game_idx][game_time_step_idx])
                feeds[tower_feeds[gpu_id][0]] = tower_states
                feeds[tower_feeds[gpu_id][1]] = tower_actions
                feeds[tower_feeds[gpu_id][2]] = tower_rewards
            _, global_step_val, loss_val, acc_val = sess.run(
                [train_op, global_step, loss, accuracy], feed_dict=feeds)
            avg_loss += loss_val
            avg_acc += acc_val
            # update epsilon
            if epsilon > final_epsilon:
                epsilon -= (init_epsilon - final_epsilon) / explore
        # average over the updates that actually ran (not over action_repeat)
        if updates_per_round > 0:
            avg_loss /= updates_per_round
            avg_acc /= updates_per_round
        batch_states.popleft()
        batch_actions.popleft()
        batch_rewards.popleft()

        global_step_val = int(global_step_val)
        elapsed_time = int(time.time() - start_time)
        logger.info(
            "train policy rl network, step=%d, epsilon=%.5f, loss=%.6f, acc=%.6f, time=%d(sec)"
            % (train_step, epsilon, avg_loss, avg_acc, elapsed_time))
        # save model
        if train_step % save_step == 0:
            params["global_step"], params["epsilon"] = global_step_val, epsilon
            param_serierlize(param_file, params)
            model_file = save_model(sess,
                                    train_dir,
                                    saver,
                                    "policy_rl_step_%d" % train_step,
                                    global_step=global_step_val)
            logger.info("save policy rl model, file=%s" % model_file)
            model_file = model_file[len(train_dir):]
            # add history model to pool
            model_pools.append(model_file)
            if len(model_pools
                   ) > max_model_pools:  # pop head when model pools exceed
                model_pools.pop(0)
            logger.info("model pools has files: [%s]" %
                        (", ".join(model_pools)))
예제 #15
0
def sampling_for_value_network(rpc,
                               sample_num,
                               sample_file,
                               max_time_steps=225):
    """
    Collect (position, reward) samples for training the value network.

    A single game is played with the policy-DL model and each position reached
    is recorded.  Every sample branches from a randomly picked recorded
    position with one random move, after which the policy-RL model plays the
    game to its end.  The position after the first policy-RL move is kept with
    the reward of the finished game (0 for a draw).

    :param rpc: client exposing ``policy_dl_rpc`` and ``policy_rl_rpc``
    :param sample_num: total number of samples wanted, including those already
        present in ``sample_file``
    :param sample_file: pickle file used to resume from and to checkpoint into
    :param max_time_steps: upper bound on the recorded time step to branch from
    :return: list of (RenjuGame, reward) tuples
    """
    sample_games = []
    if os.path.exists(sample_file):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(sample_file, 'rb') as sample_fp:
            sample_games = cPickle.load(sample_fp)
        logger.info("load sample file: %s, samples=%d" %
                    (sample_file, len(sample_games)))
    sample_sets = set()  # used to check unique sample
    game = RenjuGame()
    record_policy_dl_boards = []
    # move step by policy dl
    game.reset_game()
    record_policy_dl_boards.append(game.replicate_game())
    while True:
        action = game.choose_action(
            rpc.policy_dl_rpc(board_to_stream(game.board),
                              game.get_player_name()))
        if action is None:
            break
        _, _, terminal = game.step_games(action)
        if terminal:
            break
        record_policy_dl_boards.append(game.replicate_game())
    max_time_steps = min(max_time_steps, len(record_policy_dl_boards)) - 1
    unsaved = False  # True while sample_games holds samples not yet written
    # sample game
    while len(sample_games) < sample_num:
        sampled_game = None
        while True:  # loop to find legal sample
            flag_time_step = random.randint(1, max_time_steps)
            recorded_game = record_policy_dl_boards[flag_time_step -
                                                    1].replicate_game()
            random_action = recorded_game.random_action()
            if random_action is None:
                break
            random_state, _, terminal = recorded_game.step_games(random_action)
            if not terminal and not str(random_state) in sample_sets:
                sample_sets.add(str(random_state))
                break
        if random_action is None:  # invalid loop
            continue
        # move step by policy rl
        time_step = flag_time_step
        while True:  # simulate game by policy rl
            actions = rpc.policy_rl_rpc(board_to_stream(recorded_game.board),
                                        recorded_game.get_player_name())
            action = recorded_game.choose_action(actions)
            if action is None:  # game drawn
                sampled_reward = 0
                break
            _, reward, terminal = recorded_game.step_games(action)
            time_step += 1
            if time_step == (flag_time_step + 1):  # record board
                sampled_game = recorded_game.replicate_game()
            if terminal:  # record value
                sampled_reward = reward
                break
        if sampled_game is not None:
            sample_games.append((sampled_game, sampled_reward))
            unsaved = True
            logger.info("sample simulate, sample_step=%d, time_step=%d" %
                        (len(sample_games), time_step))
            # Checkpoint only when a sample was actually added, so the file is
            # not rewritten again and again on iterations that produce nothing.
            if len(sample_games) % 100 == 0:
                with open(sample_file, "wb") as sample_fp:
                    cPickle.dump(sample_games, sample_fp, protocol=2)
                unsaved = False
                logger.info("create value network sample, step=%d" %
                            len(sample_games))
    if unsaved:
        # Persist the tail that did not land on a multiple of 100.
        with open(sample_file, "wb") as sample_fp:
            cPickle.dump(sample_games, sample_fp, protocol=2)
    return sample_games
예제 #16
0
    def train_policy_network(self,
                             rpc,
                             batch_games=128,
                             save_step=50000,
                             max_model_pools=5,
                             init_epsilon=0.5,
                             final_epsilon=0.05,
                             explore=1000000,
                             action_repeat=20,
                             mini_batch_size=128):
        """
        Train the policy network on games generated by self-play.

        Each round plays games until ``batch_games`` games are buffered, runs
        ``action_repeat`` calls of ``self.fit`` on mini-batches sampled from the
        buffered games, and then drops the oldest game.  The loop never ends.

        :param rpc: client providing ``switch_model`` and ``policy_rl_rpc`` for
            the opponent model
        :param batch_games: number of games kept in the sampling buffer
        :param save_step: a checkpoint is written when ``train_step`` is a
            multiple of this value
        :param max_model_pools: maximum number of checkpoints kept as opponents
        :param init_epsilon: epsilon passed as the initial parameter value
        :param final_epsilon: epsilon stops decreasing once it is not above this
        :param explore: divisor of the per-update epsilon decrement
        :param action_repeat: number of mini-batch updates per round
        :param mini_batch_size: number of samples per mini-batch
        :return: never returns
        """
        game = RenjuGame()
        # one entry per buffered game; each entry is that game's per-move list
        batch_states, batch_actions, batch_rewards = deque(), deque(), deque()
        # fixed-size buffers that are overwritten in place for every mini-batch
        mini_batch_states, mini_batch_actions, mini_batch_rewards = [
            0
        ] * mini_batch_size, [0] * mini_batch_size, [0] * mini_batch_size
        model_pools = []
        # presumably restores a previously saved step/epsilon and falls back to
        # init_params otherwise -- TODO confirm in param_unserierlize
        params = self.param_unserierlize(init_params={
            "global_step": 0,
            "epsilon": init_epsilon
        })
        global_step_val, epsilon = params["global_step"], params["epsilon"]
        train_step = 0
        while True:
            start_time = time.time()
            # choose policy network for opponent player from model pools
            # (None is passed while no checkpoint has been saved yet)
            if train_step % 10 == 0:
                if len(model_pools) > 0:
                    model_file = random.choice(model_pools)
                else:
                    model_file = None
                rpc.switch_model("policy_rl", model_file=model_file)
            while len(batch_states) < batch_games:
                # colour played by the opponent (rpc) model in this game
                black_opponent = random.choice([True, False])
                # reset game
                game.reset_game()
                # simulate game by current parameter
                states, actions, rewards = [], [], []
                # presumably yields only the initial state when no action is
                # given -- TODO confirm against RenjuGame.step_games
                state = game.step_games(None)
                while True:  # loop current game
                    # self-play, current model V.S. history model
                    if random.random() < epsilon:  # random choose action
                        action = game.random_action()
                    else:
                        if (black_opponent and game.player == RenjuGame.PLAYER_BLACK) \
                                or (not black_opponent and game.player == RenjuGame.PLAYER_WHITE):
                            action = game.choose_action(
                                rpc.policy_rl_rpc(board_to_stream(game.board),
                                                  game.get_player_name()))
                        else:  # current player
                            action = game.choose_action(
                                self.predict([state])[0])
                    # step game
                    state_n, reward_n, terminal_n = game.step_games(action)
                    # store (state, action); the action is stored one-hot encoded
                    states.append(state)
                    one_hot_act = one_hot_action(action)
                    actions.append(one_hot_act)
                    # set new states
                    state = state_n
                    if terminal_n:
                        final_reward = reward_n
                        break
                    # check whether game drawn
                    if game.random_action(
                    ) is None:  # game drawn, equal end, reward=0
                        final_reward = 0
                        logger.info("game drawn, so amazing...")
                        break
                # store (reward): even plies get final_reward, odd plies its negation
                for step in xrange(len(states)):
                    if step % 2 == 0:
                        rewards.append(final_reward)
                    else:
                        rewards.append(-final_reward)
                # store states of ith game
                batch_states.append(states)
                batch_actions.append(actions)
                batch_rewards.append(rewards)
            # fit model by mini batch
            avg_loss, avg_acc = 0.0, 0.0
            for _ in xrange(action_repeat):
                train_step += 1
                # sample uniformly: first a buffered game, then a move of that game
                for idx in xrange(mini_batch_size):
                    game_idx = random.randint(0, len(batch_states) - 1)
                    game_time_step_idx = random.randint(
                        0,
                        len(batch_states[game_idx]) - 1)
                    mini_batch_states[idx] = batch_states[game_idx][
                        game_time_step_idx]
                    mini_batch_actions[idx] = batch_actions[game_idx][
                        game_time_step_idx]
                    mini_batch_rewards[idx] = batch_rewards[game_idx][
                        game_time_step_idx]
                _, global_step_val, loss, acc = self.fit(mini_batch_states,
                                                         mini_batch_actions,
                                                         mini_batch_rewards,
                                                         fetch_info=True)
                avg_loss += loss
                avg_acc += acc
                # update epsilon: linear decrement per update until final_epsilon
                if epsilon > final_epsilon:
                    epsilon -= (init_epsilon - final_epsilon) / explore
            avg_loss /= action_repeat
            avg_acc /= action_repeat
            # drop the oldest game so the next round plays exactly one new game
            batch_states.popleft()
            batch_actions.popleft()
            batch_rewards.popleft()

            global_step_val = int(global_step_val)
            elapsed_time = int(time.time() - start_time)
            logger.info(
                "train policy rl network, step=%d, epsilon=%.5f, loss=%.6f, acc=%.6f, time=%d(sec)"
                % (global_step_val, epsilon, avg_loss, avg_acc, elapsed_time))
            # save model
            # NOTE(review): train_step grows by action_repeat per round, so this
            # only fires when a multiple of action_repeat is also a multiple of
            # save_step.
            if train_step % save_step == 0:
                params["global_step"], params[
                    "epsilon"] = global_step_val, epsilon
                self.param_serierlize(params)
                model_file = self.save_model("policy_rl",
                                             global_step=global_step_val)
                # NOTE(review): the message says "dl" although the checkpoint
                # is saved under the "policy_rl" name.
                logger.info("save policy dl model, file=%s" % model_file)
                # keep only the part of the path after the model directory
                model_file = model_file[len(self.model_dir):]
                # add history model to pool
                model_pools.append(model_file)
                if len(model_pools
                       ) > max_model_pools:  # pop head when model pools exceed
                    model_pools.pop(0)
                logger.info("model pools has files: [%s]" %
                            (", ".join(model_pools)))
예제 #17
0
def train_rl_network(batch_games=128, save_step=10000,
                     max_model_pools=5, init_epsilon=0.5, final_epsilon=0.01, explore=1000000,
                     action_repeat=32, mini_batch_size=64):
    """Train the "policy_rl" network on self-play games, forever.

    Each round of the endless loop:

    1. every time ``train_step`` is a multiple of 10 at the start of a round,
       asks the RPC side to switch its "policy_rl" model to a random file from
       the pool of previously saved models (``None`` while the pool is empty);
    2. plays games between that RPC model and the in-process network until
       ``batch_games`` games are buffered;
    3. runs ``action_repeat // gpu_num`` training updates (at least one), each
       feeding every tower its own random mini batch of
       (state, action, reward) samples drawn from the buffered games;
    4. drops the oldest buffered game, logs averaged loss/accuracy, and saves
       a checkpoint whenever a multiple of ``save_step`` was reached or
       crossed during the round.

    :param batch_games: number of games kept buffered before each round of updates
    :param save_step: checkpoint cadence, in training updates
    :param max_model_pools: maximum number of saved model files kept as opponent candidates
    :param init_epsilon: starting exploration rate (used when no saved params exist)
    :param final_epsilon: lower bound of the exploration rate
    :param explore: number of updates over which epsilon decays from init to final
    :param action_repeat: mini batches consumed per round, split across ``gpu_num`` towers
    :param mini_batch_size: number of samples fed to each tower per update
    :return: never returns; the training loop has no exit condition
    """
    args = parser_argument().parse_args()
    rpc = ModelRPC(args)
    game = RenjuGame()
    # sliding window of finished games; one deque entry per game
    batch_states, batch_actions, batch_rewards = deque(), deque(), deque()
    model_pools = []
    param_file = "%s/param.json" % train_dir
    params = param_unserierlize(param_file, init_params={"global_step": 0, "epsilon": init_epsilon})
    global_step_val, epsilon = params["global_step"], params["epsilon"]
    # load model
    sess, saver, summary_writer, train_op, loss, accuracy, global_step, lr, tower_feeds, tower_logits = network()
    # Each update consumes one mini batch per tower, so a round needs
    # action_repeat // gpu_num updates; always run at least one so the loop
    # makes progress even when there are more towers than action_repeat.
    updates_per_round = max(1, action_repeat // gpu_num)
    train_step = 0
    while True:
        start_time = time.time()
        # choose policy network for opponent player from model pools
        if train_step % 10 == 0:
            model_file = random.choice(model_pools) if model_pools else None
            rpc.switch_model("policy_rl", model_file=model_file)
        while len(batch_states) < batch_games:
            # pick which colour the history (RPC) model plays in this game
            black_opponent = random.choice([True, False])
            # reset game
            game.reset_game()
            # simulate game by current parameter
            states, actions = [], []
            state = game.step_games(None)
            while True:  # loop current game
                # self-play, current model V.S. history model
                opponent_turn = (black_opponent and game.player == RenjuGame.PLAYER_BLACK) \
                    or (not black_opponent and game.player == RenjuGame.PLAYER_WHITE)
                if opponent_turn:
                    predict_probs = rpc.policy_rl_rpc(board_to_stream(game.board), game.get_player_name())
                else:  # current player
                    predict_probs = sess.run([tower_logits[0]], feed_dict={tower_feeds[0][0]: [state]})[0][0]
                if random.random() < epsilon:  # explore: sample an action from the predicted weights
                    action = game.weighted_choose_action(predict_probs)
                else:
                    action = game.choose_action(predict_probs)
                if action is None:
                    final_reward = 0
                    break
                # step game
                state_n, reward_n, terminal_n = game.step_games(action)
                # store (state, action)
                states.append(state)
                actions.append(action)
                # set new states
                state = state_n
                if terminal_n:
                    final_reward = reward_n
                    break
                # check whether game drawn
                if game.random_action() is None:  # game drawn, equal end, reward=0
                    final_reward = 0
                    logger.info("game drawn, so amazing...")
                    break
            if not states:
                # No move was recorded; buffering an empty game would make the
                # sampler below call random.randint(0, -1) and raise.
                continue
            # store (reward): even-indexed moves get final_reward, odd-indexed
            # moves get its negation (the two players alternate moves)
            rewards = [final_reward if step % 2 == 0 else -final_reward
                       for step in xrange(len(states))]
            # store states of ith game
            batch_states.append(states)
            batch_actions.append(actions)
            batch_rewards.append(rewards)
        # fit model by mini batch
        avg_loss, avg_acc = 0.0, 0.0
        step_before_round = train_step
        for _ in xrange(updates_per_round):
            train_step += 1
            feeds = {}
            for gpu_id in xrange(gpu_num):
                # Build fresh lists for every tower: feed_dict only keeps
                # references, so reusing one buffer would hand every tower
                # the batch sampled for the last one.
                sampled_states, sampled_actions, sampled_rewards = [], [], []
                for _sample in xrange(mini_batch_size):
                    game_idx = random.randint(0, len(batch_states) - 1)
                    time_idx = random.randint(0, len(batch_states[game_idx]) - 1)
                    sampled_states.append(batch_states[game_idx][time_idx])
                    sampled_actions.append(batch_actions[game_idx][time_idx])
                    sampled_rewards.append(batch_rewards[game_idx][time_idx])
                feeds[tower_feeds[gpu_id][0]] = sampled_states
                feeds[tower_feeds[gpu_id][1]] = sampled_actions
                feeds[tower_feeds[gpu_id][2]] = sampled_rewards
            _, global_step_val, loss_val, acc_val = sess.run([train_op, global_step, loss, accuracy], feed_dict=feeds)
            avg_loss += loss_val
            avg_acc += acc_val
            # update epsilon: linear decay, never stepping below final_epsilon
            if epsilon > final_epsilon:
                epsilon = max(final_epsilon,
                              epsilon - (init_epsilon - final_epsilon) / float(explore))
        # average over the number of updates actually run in this round
        avg_loss /= updates_per_round
        avg_acc /= updates_per_round
        # slide the window: drop the oldest game so a new one is played next round
        batch_states.popleft()
        batch_actions.popleft()
        batch_rewards.popleft()

        global_step_val = int(global_step_val)
        elapsed_time = int(time.time() - start_time)
        logger.info(
            "train policy rl network, step=%d, epsilon=%.5f, loss=%.6f, acc=%.6f, time=%d(sec)" %
            (train_step, epsilon, avg_loss, avg_acc, elapsed_time))
        # save model: train_step advances by several updates per round, so
        # test whether a multiple of save_step was reached or crossed instead
        # of requiring train_step to land exactly on one
        if train_step // save_step > step_before_round // save_step:
            params["global_step"], params["epsilon"] = global_step_val, epsilon
            param_serierlize(param_file, params)
            model_file = save_model(sess, train_dir, saver,
                                    "policy_rl_step_%d" % train_step,
                                    global_step=global_step_val)
            logger.info("save policy rl model, file=%s" % model_file)
            # keep only the part of the path after train_dir
            model_file = model_file[len(train_dir):]
            # add history model to pool
            model_pools.append(model_file)
            if len(model_pools) > max_model_pools:  # pop head when model pools exceed
                model_pools.pop(0)
            logger.info("model pools has files: [%s]" % (", ".join(model_pools)))