예제 #1
0
파일: selfplay.py 프로젝트: zyg1968/ZiGo
    def replay(self, pos, data=None, times=80):
        """Rewind a game move by move and finish each position with random play.

        Up to ``times`` times, a working copy of ``pos`` is taken back one
        move with ``undo()`` and a copy of that position is played out by a
        ``strategies.RandomPlayer`` until ``is_gameover``. Each position met
        during a playout is recorded with ``data.add_from_node``; ``data`` is
        saved and cleared once it holds more than 25600 samples.

        Args:
            pos: Board to start from. Rewinding and playouts act on copies.
            data: Training-data collector. It is used unconditionally inside
                the playout loop, so the ``None`` default is not usable.
            times: Maximum number of moves to take back.
        """
        baseline = strategies.fast_score(pos)
        player = strategies.RandomPlayer()
        rewound = pos.copy()
        worker = mp.current_process().name
        for back in range(times):
            if not config.running:
                break
            rewound.undo()
            playout = rewound.copy()
            mover = playout.to_move
            black_to_move = mover == go.BLACK

            # Random playout from the rewound position.
            while True:
                if not config.running or playout.is_gameover:
                    break
                move, move_values, points = player.suggest_move(playout)
                if black_to_move:
                    playout.win_rate = points
                else:
                    playout.win_rate = -points
                data.add_from_node(playout, move_values, points)
                if data.data_size % 1024 == 5:
                    print("name:", worker, ", data size:", data.data_size, file=sys.stderr)
                if data.data_size > 25600:
                    print("name:", worker, "保存训练数据……", end="")
                    data.save()
                    data.clear()
                    print("name:", worker, "保存完毕!")
                playout.play_move(move)

            # Express the playout result and its change against the baseline
            # from the point of view of the side that was to move.
            if black_to_move:
                score = playout.win_rate
                cscore = score - baseline
            else:
                score = -playout.win_rate
                cscore = -(score - baseline)

            # Only report on iterations 15, 35, 55, ...
            if back % 20 != 15:
                continue
            trend = "增加" if cscore > 0 else "减少"
            report = "倒退{}步,全探查结果:{},{}{}了{:.1f}子".format(
                back + 1, go.result_str(score), go.get_color_str(mover),
                trend, abs(cscore))
            print("name:", worker, report, file=sys.stderr)
예제 #2
0
파일: selfplay.py 프로젝트: zyg1968/ZiGo
    def play(self, pos=None, forbid_pass=True, sgfdir='sgf'):
        """Play one self-play game with MCTS and optionally save it as SGF.

        Every position for which a move is requested is recorded in
        ``self.datas``; the collector is saved and cleared once it holds more
        than 12800 samples. The game stops when ``config.running`` becomes
        false, when the step count exceeds ``go.N * go.N * 2`` (result set to
        0), on a game-ending pass, or on a resignation.

        Args:
            pos: Board to continue from; a fresh ``board.Board(komi=7.5)`` is
                created when it is falsy.
            forbid_pass: Forwarded to ``suggest_move``. When false, the first
                pass ends the game; when true, the game ends only after both
                colours have passed since the last legal move.
            sgfdir: Directory for the SGF record; nothing is saved when falsy.

        Returns:
            The board in its final state.
        """
        if not pos:
            pos = board.Board(komi=7.5)
        if self.qp:
            self.qp.start(pos)
        # Bit mask of colours that passed since the last legal move:
        # 1 = black, 2 = white, 3 = both.
        passbw = 0
        mcts = strategies.MCTSSercher(self.net, qp=self.qp)
        while config.running:
            if pos.step > go.N * go.N * 2:
                pos.result = 0
                break
            c = pos.to_move
            move, values, win_rate = mcts.suggest_move(pos, forbid_pass=forbid_pass)
            self.datas.add_from_node(pos, values, win_rate)
            if self.datas.data_size > 12800:
                print("保存训练数据……", end="", file=sys.stderr)
                self.datas.save()
                self.datas.clear()
                print("完毕!", file=sys.stderr)
            pos.win_rate = win_rate

            if move is None or move == go.PASS:
                passbw |= (1 if c == go.BLACK else 2)
                pos.play_move(go.PASS)
                if passbw == 3 or not forbid_pass:
                    score = strategies.fast_score(pos)
                    pos.result = score / 2
                    break
                continue
            if move == go.RESIGN:
                # NOTE(review): pos.result is not assigned here; presumably
                # play_move records the resignation result - confirm.
                pos.play_move(move)
                break

            illegal, _ = pos.play_move(move)
            if illegal == 0:
                passbw = 0
                if self.qp:
                    self.qp.update(pos)

        if sgfdir:
            dt = sgfdir + '/self_' + time.strftime('%Y-%m-%d_%H_%M_%S') + '.sgf'
            # exist_ok avoids the check-then-create race between concurrent
            # workers that save into the same directory.
            os.makedirs(sgfdir, exist_ok=True)
            sgfparser.save_board(pos, dt)
        return pos
예제 #3
0
파일: processplay.py 프로젝트: zyg1968/ZiGo
def play(board=None, datas=None, queue=None, net=None, qp=None, sgfdir='sgf'):
    """Play one self-play game in a worker process and optionally save it.

    Positions reached after more than ``go.N`` steps are recorded in
    ``datas``; the collector is saved and cleared once it holds more than
    12800 samples. The game stops when ``config.running`` becomes false, when
    the step count exceeds ``go.N * go.N * 2`` (result set to 0), when both
    colours have passed since the last legal move, or on a resignation.

    Args:
        board: Board to continue from; a fresh ``Board(komi=7.5)`` is created
            when it is falsy.
        datas: Training-data collector. When ``None`` no data is recorded.
        queue: Queue receiving progress messages. When ``None`` no progress
            message is sent.
        net: Network handed to the MCTS searcher.
        qp: Optional display object notified of the start position and of
            every legal move.
        sgfdir: Directory for the SGF record; nothing is saved when falsy.

    Returns:
        The board in its final state.
    """
    if not board:
        # The ``board`` parameter shadows the ``board`` module inside this
        # function, so the module is fetched from the module globals.
        board = globals()['board'].Board(komi=7.5)
    if qp:
        qp.start(board)
    name = mp.current_process().name
    # Bit mask of colours that passed since the last legal move:
    # 1 = black, 2 = white, 3 = both.
    passbw = 0
    # NOTE(review): selfplay.py spells this class ``MCTSSercher``; confirm
    # which name the strategies module actually exports.
    mcts = strategies.MTCSSercher(net=net, qp=qp)
    while config.running:
        if board.step > go.N * go.N * 2:
            board.result = 0
            break
        c = board.to_move
        move, values, points = mcts.suggest_move(board)
        if datas is not None:
            if board.step > go.N:
                datas.add_from_node(board, values, points)
            if queue is not None and datas.data_size % 1024 == 5:
                queue.put("{}-- data size: {}".format(name, datas.data_size))
            if datas.data_size > 12800:
                print("保存训练数据……", end="", file=sys.stderr)
                datas.save()
                datas.clear()
                print("完毕!", file=sys.stderr)
        board.win_rate = points if c == go.BLACK else -points

        if move is None or move == go.PASS:
            passbw |= (1 if c == go.BLACK else 2)
            board.play_move(go.PASS)
            if passbw == 3:
                score = strategies.fast_score(board)
                board.result = score / 2
                break
            continue
        if move == go.RESIGN:
            board.play_move(move)
            # Resignation is stored as a result one point beyond a full board.
            board.result = go.N * go.N + 1 if c == go.WHITE else -go.N * go.N - 1
            break

        illegal, _ = board.play_move(move)
        if illegal == 0:
            passbw = 0
            if qp:
                qp.update(board)

    if sgfdir:
        dt = sgfdir + '/self_' + time.strftime('%Y-%m-%d_%H_%M_%S') + '.sgf'
        # exist_ok avoids the check-then-create race between worker processes
        # that save into the same directory.
        os.makedirs(sgfdir, exist_ok=True)
        sgfparser.save_board(board, dt)
    return board
예제 #4
0
파일: processplay.py 프로젝트: zyg1968/ZiGo
def replay(board, data, queue, times=80, qp=None):
    """Rewind a game move by move and finish each position with random play.

    Sets ``config.running`` to true, then up to ``times`` times takes a
    working copy of ``board`` back one move with ``undo()`` and lets a
    ``strategies.RandomPlayer`` play a copy of that position until
    ``is_gameover``. Rewinding stops once the working copy has fewer than two
    steps left. Each position met during a playout is recorded with
    ``data.add_from_node``; ``data`` is saved and cleared once it holds more
    than 12800 samples.

    Args:
        board: Board to start from. Rewinding and playouts act on copies.
        data: Training-data collector.
        queue: Queue receiving progress messages prefixed with the current
            process name.
        times: Maximum number of moves to take back.
        qp: Optional display object notified when a playout starts and after
            every playout move.
    """
    baseline = strategies.fast_score(board)
    player = strategies.RandomPlayer()
    rewound = board.copy()
    worker = mp.current_process().name
    config.running = True
    for back in range(times):
        if not config.running or rewound.step < 2:
            break
        rewound.undo()
        playout = rewound.copy()
        if qp:
            qp.start(playout)
        mover = playout.to_move
        black_to_move = mover == go.BLACK

        # Random playout from the rewound position.
        while True:
            if not config.running or playout.is_gameover:
                break
            move, move_values, points = player.suggest_move(playout)
            if black_to_move:
                playout.win_rate = points
            else:
                playout.win_rate = -points
            data.add_from_node(playout, move_values, points)
            if data.data_size % 1024 == 5:
                queue.put("{}-- data size: {}".format(worker, data.data_size))
            if data.data_size > 12800:
                queue.put(worker + "--保存训练数据……")
                data.save()
                data.clear()
                queue.put(worker + "--保存完毕!")
            playout.play_move(move)
            if qp:
                qp.update(playout)

        # Express the playout result and its change against the baseline from
        # the point of view of the side that was to move.
        if black_to_move:
            score = playout.win_rate
            cscore = score - baseline
        else:
            score = -playout.win_rate
            cscore = -(score - baseline)

        # Only report on iterations 15, 35, 55, ...
        if back % 20 != 15:
            continue
        trend = "增加" if cscore > 0 else "减少"
        queue.put("{}--倒退{}步,全探查结果:{},{}{}了{:.1f}子, {}步".format(
            worker, back + 1, go.result_str(score), go.get_color_str(mover),
            trend, abs(cscore), playout.step))
예제 #5
0
    def search_tree(self, node, board, depth=0):
        """Exhaustively search the subtree under ``node`` and return its value.

        Results are appended to ``self.data`` (written to ``self.sql`` once
        more than 1000 entries are pending) and to ``self.datas`` (saved and
        cleared once it holds more than 6400 samples).

        Args:
            node: Tree node for the position in ``board``.
            board: Position that ``node`` represents.
            depth: Recursion depth of the caller; incremented on entry.

        Returns:
            ``node.points`` divided by ``node.visits`` (or by 1 when the node
            has no visits), the value found in the database, the terminal
            score from the side to move's point of view, or 0 when the depth
            limit ``go.N * go.N * 2`` is exceeded.
        """
        depth += 1
        if len(self.data) > 1000:
            print("保存到数据库,depth:",
                  depth,
                  ", self.data len:",
                  len(self.data),
                  file=sys.stderr)
            self.sql.save(self.data)
            self.data = []

        if node.search_over:
            return node.points / (node.visits if node.visits > 0 else 1)
        if depth > go.N * go.N * 2:
            node.search_over = True
            return 0
        # NOTE(review): values <= -9999 are treated as "not in the database";
        # confirm against the sql.search implementation.
        p = self.sql.search(board)
        if p > -9999:
            node.update(p)
            return p

        move_values = [0.0] * (go.N * go.N + 2)
        tomove = board.to_move
        if node.expanded and (board.is_gameover or not node.childs):
            # Terminal position: score it from the side to move's view.
            node.search_over = True
            score = strategies.fast_score(board)
            ps = score if tomove == go.BLACK else -score
            node.update(ps)
            self.data.append(TrainData(board, ps))
            move_values[-2] = 1.0
            self.datas.add_from_node(board, move_values, ps)
            if self.datas.data_size > 6400:
                self.datas.save()
                self.datas.clear()
            return ps

        # NOTE(review): points are reset before the children are searched;
        # presumably the child updates accumulate back into this node.
        node.points = 0.0
        for child in node.childs:
            if not config.running:
                break
            b = board.copy()
            b.play_move(child.move)
            v = 0.0
            if not child.expanded:
                self.expand(child, b)
                v = self.search_tree(child, b, depth)
            elif not child.search_over:
                # Fixed: the original passed ``b.depth`` as the board and
                # dropped the depth argument.
                v = self.search_tree(child, b, depth)
            ind = go.flatten_coords(child.move)
            move_values[ind] = v
        visits = node.visits if node.visits > 0 else 1
        # Mark the node finished only when the search was not interrupted and
        # the depth limit was not reached.
        if config.running and depth < int(go.N * go.N * 2):
            node.search_over = True
            self.data.append(TrainData(board, node.points / visits))
            self.datas.add_from_node(board, move_values, node.points / visits)
            if self.datas.data_size > 6400:
                self.datas.save()
                self.datas.clear()
        return node.points / visits
예제 #6
0
파일: selfplay.py 프로젝트: zyg1968/ZiGo
def test_play(net, old_net, qp=None, sgfdir=None, bw=go.BLACK):
    """Play one evaluation game between two networks and update their Elo.

    ``net`` plays the colour ``bw`` and ``old_net`` plays the other colour.
    Each move is the most likely move of the network to play, as chosen by
    ``strategies.select_most_likely``. The game stops when ``config.running``
    becomes false, when the step count exceeds ``go.N * go.N * 2``, when both
    colours have passed since the last legal move, or on a resignation.

    Args:
        net: Network under test; its ``elo`` attribute is updated.
        old_net: Reference network; its ``elo`` attribute is updated.
        qp: Optional display object receiving board updates and messages.
        sgfdir: Directory for the SGF record; nothing is saved when falsy.
        bw: Colour played by ``net``.

    Returns:
        0 when the step limit was exceeded (Elo is left untouched); otherwise
        half of the final score from ``net``'s point of view.
    """
    test_start = time.time()
    pos = board.Board(komi=7.5)
    if qp:
        qp.start(pos)
    # Bit mask of colours that passed since the last legal move:
    # 1 = black, 2 = white, 3 = both.
    passbw = 0
    while config.running:
        c = pos.to_move
        if pos.step > go.N * go.N * 2:
            pos.result = 0
            msg = '双方超过722手,对局作为和棋结束。'
            if qp:
                qp.show_message(msg=msg)
            break
        if c == bw:
            move_probs, win_rate = net.run(pos)
        else:
            move_probs, win_rate = old_net.run(pos)
        pos.win_rate = win_rate * go.N * go.N

        # First True: check legality; second True: refuse PASS.
        move = strategies.select_most_likely(pos, move_probs, True, True)
        if move == go.RESIGN:
            # NOTE(review): compares against the literal 1; presumably
            # go.BLACK == 1 - confirm.
            pos.result = -go.N * go.N - 1 if c == 1 else go.N * go.N + 1
            msg = '{}方第{}手认输,对局结束。{}中盘胜。'.format(go.get_color_str(c), pos.step, go.get_color_str(-c))
            if qp:
                qp.show_message(msg=msg)
            break
        if move is None or move == go.PASS:
            passbw |= (1 if c == go.BLACK else 2)
            pos.play_move(go.PASS)
            if passbw == 3:
                msg = '%s方第%d手PASS,对局结束。' % (go.get_color_str(c), pos.step)
                if qp:
                    qp.show_message(msg=msg)
                break
            continue
        illegal, caps = pos.play_move(move)
        if illegal == 0:
            passbw = 0
            # NOTE(review): the returned value is never used; the call is
            # kept in case pos.score() has side effects - confirm.
            score = pos.score()
            if qp:
                qp.update(pos)

    if pos.step > go.N * go.N * 2:
        return 0
    if abs(pos.result) < go.N * go.N:
        # Not decided by resignation: score the final position.
        n = pos.step
        score = strategies.fast_score(pos)
        print("fast score: {}, n: {}/{}".format(score, n, pos.step), file=sys.stderr)
    else:
        score = pos.result * 2
    # From here on a positive score means that ``net`` won.
    score *= 1 if bw == go.BLACK else -1

    # Elo update: expected scores for both sides, then a rating-dependent
    # factor applied to (actual - expected).
    eloa = net.elo
    elob = old_net.elo
    ea = 1 / (1 + 10.0 ** ((elob - eloa) / 400))
    eb = 1 / (1 + 10.0 ** ((eloa - elob) / 400))
    win = 1 if score > 0 else 0
    xs = 14.5 + 17500.0 / (6000 + net.elo)
    addea = xs * (win - ea)
    addeb = xs * (1 - win - eb)
    print('old elo1: {:.2f}, 2: {:.2f}, add elo 1: {:.2f}, 2: {:.2f}'.format(
        eloa, elob, addea, addeb), file=sys.stderr)
    net.elo += addea
    net.elo = max(-5000, net.elo)
    old_net.elo += addeb
    old_net.elo = max(-5000, old_net.elo)
    pos.result = score / 2

    # gmtime, not localtime: the value is a duration, so the local UTC offset
    # must not be added (it corrupts the minutes in zones with a 30- or
    # 45-minute offset).
    elapsed = time.strftime("%Mm%Ss", time.gmtime(time.time() - test_start))
    msg = '新权重:{},用时:{},{}手,{}方胜{:.1f}子, 增加{:.2f}/{:.2f}分。'.format(go.get_color_str(bw),
        elapsed,
        pos.step + 1, '新权重' if score > 0 else '老权重', abs(score / 2.0), addea, addeb)
    if qp:
        qp.show_message(msg=msg)
    if sgfdir:
        dt = sgfdir + '/test_' + time.strftime('%Y-%m-%d_%H_%M_%S') + '.sgf'
        msg = '%.1f:\t保存sgf棋谱文件到%s' % (time.time() - policy.start_time, dt)
        if qp:
            qp.show_message(msg=msg)
        # exist_ok avoids the check-then-create race between concurrent
        # writers that save into the same directory.
        os.makedirs(sgfdir, exist_ok=True)
        sgfparser.save_board(pos, dt)
    return score / 2