async def policy_value_fn_queue(self, state, loop):
        """Evaluate ``state`` through ``self.push_queue`` and keep legal moves.

        The board array from ``BaseChessBoard`` is encoded with
        ``boardarr2netinput``, transposed with axes ``[1, 2, 0]`` and given a
        leading axis of size 1 before being pushed onto the queue. The future
        returned by ``push_queue`` is expected to resolve to
        ``(policyout, valout)``.

        Args:
            state: game state exposing ``statestr``, ``currentplayer`` and
                ``get_current_player()``.
            loop: event loop, forwarded unchanged to ``self.push_queue``.

        Returns:
            tuple: ``(action_probs, value)`` where ``action_probs`` is a list
            of ``(move, prob)`` pairs restricted to legal moves, in
            ``uci_labels`` order, and ``value`` is ``valout[0]``.
        """
        bb = BaseChessBoard(state.statestr)
        board_arr = bb.get_board_arr()
        net_x = np.transpose(
            boardarr2netinput(board_arr, state.get_current_player()), [1, 2, 0])
        net_x = np.expand_dims(net_x, 0)

        # push_queue hands back a future; awaiting it yields its result
        # directly, so a separate future.result() call is unnecessary.
        future = await self.push_queue(net_x, loop)
        policyout, valout = await future
        value = valout[0]

        legal_moves = set(GameBoard.get_legal_moves(state.statestr,
                                                    state.get_current_player()))

        if state.currentplayer == 'b':
            # For black, legality is tested against flipped labels and each
            # accepted label is flipped back before being returned.
            # NOTE(review): presumably the policy output is in a flipped
            # orientation for black — confirm against the network encoding.
            legal_flipped = set(board.flipped_uci_labels(legal_moves))
            action_probs = [(board.flipped_uci_labels([move])[0], prob)
                            for move, prob in zip(uci_labels, policyout)
                            if move in legal_flipped]
        else:
            action_probs = [(move, prob)
                            for move, prob in zip(uci_labels, policyout)
                            if move in legal_moves]
        return action_probs, value
示例#2
0
async def policy_value_fn_async_batch(state):
    """Evaluate ``state`` via the ``work`` task and keep only legal moves.

    The task is submitted with ``work.delay`` and polled, sleeping ``1e-3``
    seconds between ``result.ready()`` checks so the event loop is not
    blocked while the task is pending. The task result is expected to be
    ``(policyout, valout)``.

    Args:
        state: game state exposing ``statestr``, ``currentplayer`` and
            ``get_current_player()``.

    Returns:
        tuple: ``(action_probs, valout)`` where ``action_probs`` is a list of
        ``(move, prob)`` pairs restricted to legal moves, sorted by ascending
        ``prob``, and ``valout`` is the task's value output as returned.
    """
    result = work.delay((state.statestr, state.get_current_player()))
    # Poll rather than calling result.get() immediately: get() would block
    # the event loop until the task finishes.
    while not result.ready():
        await asyncio.sleep(1e-3)
    policyout, valout = result.get()

    legal_moves = set(GameBoard.get_legal_moves(state.statestr,
                                                state.get_current_player()))

    if state.currentplayer == 'b':
        # For black, legality is tested against flipped labels and each
        # accepted label is flipped back before being returned.
        # NOTE(review): presumably the policy output is in a flipped
        # orientation for black — confirm against the network encoding.
        legal_flipped = set(board.flipped_uci_labels(legal_moves))
        action_probs = [(board.flipped_uci_labels([move])[0], prob)
                        for move, prob in zip(uci_labels, policyout)
                        if move in legal_flipped]
    else:
        action_probs = [(move, prob)
                        for move, prob in zip(uci_labels, policyout)
                        if move in legal_moves]

    # Stable ascending sort by probability (same ordering as before).
    action_probs.sort(key=lambda pair: pair[1])
    return action_probs, valout