Exemplo n.º 1
0
 def simulate(self, ts, board, player):
     """Play one game guided by the search tree ``ts`` and record every move.

     Each turn an action is chosen with ``self.epsilon_greedy``. If the key
     ``(board_str, from_, act)`` is already in ``self.predicts``, that edge
     is pruned from ``ts.root.sub_edge`` and from the valid actions, and
     another action is drawn. The exception is when only one edge remains.
     After each move the board is flipped and ``player`` negated.

     Args:
         ts: search tree; ``ts.root.sub_edge``, ``ts.move_down`` and
             ``ts.show_info`` are used.
         board: starting board. A copy is stored with each move. The board
             itself is passed to ``rule.move`` (presumably mutated in place).
         player: the player to move first (compared against board cells).

     Returns:
         ``(records, player)`` when ``player`` wins.
         ``(Record(), 0)`` when a ``NoActionException`` is raised, or when the
         game exceeds 10000 recorded moves.

     Raises:
         Any other exception is re-raised unchanged after the board, player,
         valid actions and chosen action are logged, ``ts.show_info()`` is
         called, and the partial records are saved to
         ``records/train/1st_``.
     """
     from record import Record
     from value_network import NoActionException
     records = Record()
     while True:
         # Reset every turn so the error handler never touches an unbound
         # name (first turn) or reports the previous turn's action.
         from_ = act = None
         try:
             bd = board.copy()
             board_str = util.board_str(board)
             valid_action = rule.valid_actions(board, player)
             while True:
                 (from_, act), q = self.epsilon_greedy(
                     board, player, valid_action, ts)
                 if (board_str, from_, act) not in self.predicts or len(
                         ts.root.sub_edge) == 1:
                     break
                 # Already predicted for this board: prune the edge and
                 # draw another action.
                 ts.root.sub_edge = [
                     e for e in ts.root.sub_edge if e.a != (from_, act)
                 ]
                 valid_action.remove((from_, act))
             assert board[from_] == player
             ts.move_down(board, player, action=(from_, act))
             if self.episode % 10 == 0:
                 logger.info('action:%s,%s', from_, act)
                 logger.info('q is %s', q)
             to_ = tuple(np.add(from_, rule.actions_move[act]))
             command, eat = rule.move(board, from_, to_)
             records.add3(bd, from_, act, len(eat), win=command == rule.WIN)
         except NoActionException:
             # After a random initial position, one side has no legal move.
             return Record(), 0
         except Exception:
             logging.warning('board is:\n%s', board)
             logging.warning('player is: %s', player)
             valid = rule.valid_actions(board, player)
             logging.warning('valid is:\n%s', valid)
             logging.warning('from_:%s, act:%s', from_, act)
             ts.show_info()
             records.save('records/train/1st_')
             # Bare raise re-raises the original exception untouched.
             raise
         if command == rule.WIN:
             logging.info('%s WIN, step use: %s, epsilon:%s', str(player),
                          records.length(), self.epsilon)
             return records, player
         if records.length() > 10000:
             logging.info('走子数过多: %s', records.length())
             return Record(), 0
         player = -player
         board = rule.flip_board(board)
Exemplo n.º 2
0
def simulate(nw0, nw1, activation, init='fixed'):
    """Play one game between two networks and collect training records.

    ``nw0`` moves for player 1 and ``nw1`` for player -1. The starting player
    is chosen at random.

    When a ``(board, player)`` position repeats, the moves forming the cycle
    are cut out of ``records``. Those moves are marked with ``draw()`` (target
    0.5) and used to train both networks immediately.

    Args:
        nw0: network for player 1. Its ``policy``, ``train`` and
            ``epsilon`` are used.
        nw1: network for player -1, used the same way.
        activation: ``'sigmoid'``, ``'linear'`` or ``'selu'``. Selects
            ``Record.add3``, ``add2`` or ``add4`` respectively for storing
            each move.
        init: ``'fixed'`` starts from ``rule.init_board(player)`` and flips
            the board after every move. Any other value starts from
            ``rule.random_init_board()`` and never flips.

    Returns:
        ``(records, player)`` when ``player`` wins.
        ``(records, 0)`` after ``records.clear()`` when the gap between steps
        played and records kept (moves discarded as cycles) exceeds 500.
        ``(Record(), 0)`` when a ``NoActionException`` is raised.

    Raises:
        ValueError: ``activation`` is not one of the supported values.
            This is checked before any move is played.
        Any other exception is re-raised unchanged after the game state is
        logged and the partial records are saved to ``records/train/1st_``.
    """
    # Which Record method encodes a move for each activation.
    add_methods = {'sigmoid': 'add3', 'linear': 'add2', 'selu': 'add4'}
    if activation not in add_methods:
        raise ValueError('unknown activation: %r' % (activation,))

    np.random.seed(util.rand_int32())
    player = 1 if np.random.random() > 0.5 else -1
    logger.info('init:%s, player:%s', init, player)
    board = rule.init_board(
        player) if init == 'fixed' else rule.random_init_board()
    records = Record()
    add_record = getattr(records, add_methods[activation])
    boards = set()  # {(board_str, player)} for every move still in records
    nws = [None, nw0, nw1]  # indexed by player: nws[1] is nw0, nws[-1] is nw1
    n_steps = 0
    while True:
        nw = nws[player]
        # Reset every turn so the error handler never touches an unbound
        # name (first turn) or reports the previous turn's move.
        from_ = action = None
        try:
            bd = board.copy()
            board_str = util.board_str(board)

            if (board_str, player) in boards:
                # Find the cycle, train both networks on it with a 0.5
                # (draw) target, then remove it from the game.
                # Search from the newest record. The bound is the list that
                # is actually indexed and sliced below.
                for i in range(len(records.records) - 1, -1, -1):
                    b, f, _, _, _ = records[i]
                    if (b == board).all() and b[f] == player:
                        break
                else:
                    raise AssertionError((board, player))
                records2 = Record()
                records2.records = records.records[i:]
                records2.draw()
                nw0.train(records2)
                nw1.train(records2)

                # Drop the cycle's moves from the records and position set.
                records.records = records.records[:i]
                for b, f, _, _, _ in records2:
                    boards.remove((util.board_str(b), b[f]))
                logger.info('环:%s, records:%s, epsilon:%s', len(records2),
                            records.length(), nw.epsilon)
            boards.add((board_str, player))

            from_, action = nw.policy(board, player)
            assert board[from_] == player
            to_ = tuple(np.add(from_, rule.actions_move[action]))
            command, eat = rule.move(board, from_, to_)
            reward = len(eat)
            add_record(bd, from_, action, reward, win=command == rule.WIN)
            if command == rule.WIN:
                logging.info('%s WIN, stone:%s, step use: %s, epsilon:%s',
                             str(player), (board == player).sum(),
                             records.length(), nw.epsilon)
                return records, player
            if n_steps - records.length() > 500:
                # Too many moves were spent in cycles: declare a draw.
                logging.info('循环走子数过多: %s', records.length())
                records.clear()
                return records, 0

            player = -player
            if init == 'fixed':
                board = rule.flip_board(board)
            n_steps += 1
        except NoActionException:
            # After a random initial position, one side has no legal move.
            return Record(), 0
        except Exception:
            logging.warning('board is:\n%s', board)
            logging.warning('player is: %s', player)
            valid = rule.valid_actions(board, player)
            logging.warning('valid is:\n%s', valid)
            logging.warning('predict is:\n%s', nw.q_value)
            logging.warning('valid action is:\n%s', nw.valid)
            logging.warning('from:%s, action:%s', from_, action)
            records.save('records/train/1st_')
            # Bare raise re-raises the original exception untouched.
            raise