Example #1
 def search_minibatch(self, count, s, player, net, zugNr, zugMax, device):
     """
     Perform several MCTS searches.
     """
     # return: number of find_leaf calls that played the game through to the end
     countEnd = 0
     backup_queue = []
     expand_states = []
     expand_players = []
     expand_queue = []
     for _ in range(count):
         value, leaf_state, leaf_player, states, actions = self.find_leaf(s, player, zugNr, zugMax)
         if value is not None:
             countEnd += 1
             backup_queue.append((value, states, actions))
         else:
             # expand_states holds the board form, so deduplicate against the
             # raw leaf state kept in expand_queue
             found = False
             for queued_state, _, _ in expand_queue:
                 if queued_state == leaf_state:
                     found = True
                     break
             if not found:
                 expand_states.append(gameGo.intToB(leaf_state))
                 expand_players.append(leaf_player)
                 expand_queue.append((leaf_state, states, actions))
     # expansion of nodes
     if expand_queue:
         batch_v = gameGo.state_lists_to_batch(expand_states, expand_players, device)
         logits_v, values_v = net(batch_v)
         probs_v = F.softmax(logits_v, dim=1)
         values = values_v.data.cpu().numpy()[:,0]
         probs = probs_v.detach().cpu().numpy()
         # create the nodes
         for (leaf_state, states, actions), value, prob in zip(expand_queue, values, probs):
             self.stateStats.expand(leaf_state, prob)
             backup_queue.append((value, states, actions))
     # backup with value, states, actions
     for value, states, actions in backup_queue:
         # leaf state not stored in states + actions, so the value of the leaf will be the value of the opponent
         cur_value = -value
         for state, action in zip(states[::-1], actions[::-1]):
             self.stateStats.backup(state, action, cur_value)
             cur_value = -cur_value
     return countEnd
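The backup loop above propagates a terminal or network value back along the visited path with alternating signs, since each ply switches the player to move. Below is a minimal stand-alone sketch of that idea on a plain per-(state, action) dictionary; backup_path, visit_counts and total_values are illustrative names, not from the project.

# Illustrative sketch of the alternating-sign backup, not part of the project.
def backup_path(visit_counts, total_values, path, leaf_value):
    cur_value = -leaf_value              # the leaf value belongs to the opponent
    for state, action in reversed(path):
        key = (state, action)
        visit_counts[key] = visit_counts.get(key, 0) + 1
        total_values[key] = total_values.get(key, 0.0) + cur_value
        cur_value = -cur_value           # flip the sign for the previous ply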
Example #2
 def search(self, count, batch_size, s, player, net, zugNr, zugMax, device):
     # return: number of find_leaf calls that played the game through to the end
     countEnd = 0
     if batch_size > 0:
         for _ in range(count):
             countEndMini = self.search_minibatch(batch_size, s, player, net, zugNr, zugMax, device)
             countEnd += countEndMini
     else:
         for _ in range(count):
             value, leaf_state, leaf_player, states, actions = self.find_leaf(s, player, zugNr, zugMax)
             if value is None:
                 # expand with leaf_state, leaf_player, states, actions
                 batch_v = gameGo.state_lists_to_batch([gameGo.intToB(leaf_state)], [leaf_player], device)
                 logits_v, value_v = net(batch_v)
                 probs_v = F.softmax(logits_v, dim=1)
                 probs = probs_v.detach().cpu().numpy()[0]
                 value = value_v.data.cpu().numpy()[0][0]
                 # create the node
                 self.stateStats.expand(leaf_state, probs)
             else:
                 countEnd += 1
                 print('Leaf reached the end of the game.')
                 cv = -value
                 cp = leaf_player
                 for state, action in zip(states[::-1], actions[::-1]):
                     print('backup with action: ', action, ' player: ', cp, ' value: ', cv, ' at:')
                     cv = -cv
                     cp = 1-cp
                     gameGo.printBrett(gameGo.intToB(state)[0])
             # backup with value, states, actions
             # leaf state not stored in states + actions, so the value of the leaf will be the value of the opponent
             cur_value = -value
             for state, action in zip(states[::-1], actions[::-1]):
                 self.stateStats.backup(state, action, cur_value)
                 cur_value = -cur_value
     return countEnd
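A hypothetical call to the method above might look as follows; mcts, net, state and ZUG_MAX are assumed to exist elsewhere in the project, and batch_size=0 would fall back to one network evaluation per leaf.

# Hypothetical usage; mcts, net, state and ZUG_MAX are assumptions.
count_end = mcts.search(count=16, batch_size=8, s=state, player=1, net=net,
                        zugNr=1, zugMax=ZUG_MAX, device='cpu')
print('find_leaf calls that reached the end of the game:', count_end)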
Example #3
def play_game(mcts_stores, replay_buffer, net1, net2, steps_before_tau_0,
              mcts_searches, mcts_batch_size=0, stat='nicht', device='cpu'):
    """
    Play one single game, memorizing transitions into the replay buffer
    :param mcts_stores: could be single MCTS or two MCTSes for individual net
    :param replay_buffer: queue with (state, probs, values), if None, nothing is stored
    :param net1: player1
    :param net2: player2
    :param mcts_batch_size: batch size for the MCTS minibatch search, 0: no minibatch call
    :return: value for the game with respect to net1 (+1 if p1 won, -1 if lost, 0 if draw)
    Statistics: the share of leaf calls reaching game end is measured during the first evaluation game,
                the differences between MCTS and NN moves during the last evaluation game;
                both are controlled via PLAY_STATISTIK: off, summary only, detailed
    """
#    assert isinstance(replay_buffer, (collections.deque, type(None)))
#    assert isinstance(mcts_stores, (mctsGo.MCTS, type(None), list))
#    assert isinstance(net1, NnGo)
#    assert isinstance(net2, NnGo)
    if isinstance(mcts_stores, mctsGo.MCTS):
        mcts_stores = [mcts_stores, mcts_stores]
    spiel = goSpielNoGraph.PlayGo(gameGo.b2Initial, zugMax=ZUG_MAX)
    state = spiel.bToInt()
    nets = [net1, net2]
    cur_player = 1 # black always starts, and that is net1
    step = 0
    countDiff = 0
    countEnd = 0
    countSearch = mcts_searches * mcts_batch_size if mcts_batch_size > 0 else mcts_searches
    tau = 1 if steps_before_tau_0 > 0 else 0
    game_history = []
    values, zuege = [], []
    while True:
        statEnd = mcts_stores[1-cur_player].search(mcts_searches, mcts_batch_size, state, cur_player,
                                        nets[1-cur_player], zugNr = step+1, zugMax=ZUG_MAX, device=device)
        countEnd += statEnd
        probs = mcts_stores[1-cur_player].get_policy(state, tau=tau)
        game_history.append((state, cur_player, probs))
        action = np.random.choice(gameGo.ANZ_POSITIONS, p=probs)
        if not spiel.setzZug(action):   # make the chosen move on the board
            print('Impossible action at step ', step, ', Player: ', cur_player, '. Action=', action, ' at:')
            spiel.printB()
            print('b1:')
            spiel.printB(b1=True)
            print('with probs:')
            gameGo.printBrett(probs, istFlat=True, mitFloat=True)
            counts = mcts_stores[1-cur_player].stateStats.b[state][0]
            print('Counts:')
            gameGo.printBrett(counts, istFlat=True)
            counts[action] = 0
            if not spiel.setzZug(np.argmax(counts)):
                spiel.setzZug(81)
        elif PLAY_STATISTIK == 1:
            zuege.append(action)
            values.append('%1.2f ' % (mcts_stores[1-cur_player].stateStats.b[state][2][action]))
        if PLAY_STATISTIK > 0 and stat != 'nicht':
            batch_v = gameGo.state_lists_to_batch([gameGo.intToB(state)], [cur_player], device)
            p_v, _ = nets[1-cur_player](batch_v)
            probs = p_v.detach().cpu().numpy()[0]
            position = np.argmax(probs)
            if position != action:
                countDiff += 1
                if PLAY_STATISTIK == 2:
                    print('play_game step ', step+1, ' action differs!')
                    print('Action  MCTS: ', action, '  NN: ', position)
                    print('Share of leaf calls reaching game end: '+str(statEnd)+' = '+str(int(statEnd*100/countSearch))+'%')
                    print('')
        if spiel.spielBeendet:
#            print('Winner:', spiel.gewinner, 'B:', spiel.pktSchwarz, 'W:', spiel.pktWeiss)
            if PLAY_STATISTIK == 1:
                spiel.sgfWrite(zuege, values)
            if spiel.gewinner == 1:
                net1_result = 1
                if cur_player == 1:
                    result = 1
                else:
                    result = -1
            elif spiel.gewinner == -1:
                net1_result = -1
                if cur_player == 1:
                    result = -1
                else:
                    result = 1
            else:
                result = 0
                net1_result = 0
            break
        cur_player = 1-cur_player
        state = spiel.bToInt()
        step += 1
        if step >= steps_before_tau_0:
            tau = 0
    if PLAY_STATISTIK > 0:
        if stat == 'Diff':
            print('play_game differences MCTS vs NN: '+str(countDiff)+' = '+str(int(countDiff*100/(step+1)))+'%')
        elif stat == 'Leaf':
            print('Total share of leaf calls reaching game end: '
              + str(countEnd) + ' = ' + str(int(countEnd*100/(countSearch*(step+1)))) + '%')
    if replay_buffer is not None:
        for state, cur_player, probs in reversed(game_history):
            for drehung in (0, 90, 180, 270, 1, 2, 3, 4):
                replay_buffer.append((gameGo.drehB2(state, drehung), cur_player,
                                      gameGo.drehPosition(probs, drehung), result))
            result = -result
    return net1_result, step
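When the replay buffer is filled at the end of play_game, the game history is walked backwards and the result sign is flipped after each position, so every stored sample is labelled from the perspective of the player who moved there. Below is a stand-alone sketch of that bookkeeping; label_history and its arguments are illustrative names.

# Illustrative sketch of the result-sign bookkeeping, not part of the project.
def label_history(game_history, final_result):
    labelled = []
    result = final_result
    for state, player, probs in reversed(game_history):
        labelled.append((state, player, probs, result))
        result = -result                 # the preceding position was the opponent's turn
    return labelled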
Example #4
def trainMCTS(net, aufsatz, device):
    # input: net, index of the RL run to resume from (0: start from scratch)
    # return: index of the best net
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9)
    replay_buffer = collections.deque(maxlen=REPLAY_BUFFER)
    if aufsatz > 0:
        model = 'RL' + str(aufsatz)
        with open(dirSave + '/go_' + model + '.txt', "r") as infile:
            for line in infile:
                items = line.split('_')
                i2 = items[2][1:-1]
                i2 = i2.split(',')
                i2List = []
                for i in range(gameGo.ANZ_POSITIONS):
                    i2List.append(float(i2[i]))
                replay_buffer.append((int(items[0]), int(items[1]), i2List, int(items[3])))
    mcts_store = mctsGo.MCTS()
    best_idx = aufsatz
    bestNet = copy.deepcopy(net)
    timeTrain = gameGo.GoTimer('trainMCTS', mitGesamt=True)
    timeGame = gameGo.GoTimer('play_game')
    prev_nodes = 0
    for step_idx in range(1, MAX_STEP+1):
        game_steps = 0
        timeTrain.start()
        timeGame.start()
        for _ in range(PLAY_EPISODES):
            _, steps = play_game(mcts_store, replay_buffer, bestNet, bestNet,
                        steps_before_tau_0=STEPS_BEFORE_TAU_0, mcts_searches=MCTS_SEARCHES,
                        mcts_batch_size= MCTS_BATCH_SIZE, device=device)
            game_steps += steps
        timeGame.stop()
        if step_idx % PRINT_EVERY_STEP == 0:
            game_nodes = len(mcts_store) - prev_nodes
            prev_nodes = len(mcts_store)
            print("Step %d, Moves last step %3d, New leaves %3d, Best net %d, Replay size %d" % (
                step_idx, game_steps, game_nodes, best_idx, len(replay_buffer)))
        if len(replay_buffer) < MIN_REPLAY_TO_TRAIN:
            continue
        # train
        sum_loss = 0.0
        sum_value_loss = 0.0
        sum_policy_loss = 0.0
        for _ in range(TRAIN_ROUNDS):
            batch = random.sample(replay_buffer, BATCH_SIZE)
            batch_states, batch_who_moves, batch_probs, batch_values = zip(*batch)
            batch_states_lists = [gameGo.intToB(state) for state in batch_states]
            states_v = gameGo.state_lists_to_batch(batch_states_lists, batch_who_moves, device)
            optimizer.zero_grad()
            probs_v = torch.FloatTensor(batch_probs).to(device)
            values_v = torch.FloatTensor(batch_values).to(device)
            out_logits_v, out_values_v = net(states_v)

            loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
            loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
            loss_policy_v = loss_policy_v.sum(dim=1).mean()

            loss_v = loss_policy_v + loss_value_v
            loss_v.backward()
            optimizer.step()
            sum_loss += loss_v.item()
            sum_value_loss += loss_value_v.item()
            sum_policy_loss += loss_policy_v.item()
        if step_idx % PRINT_EVERY_STEP == 0:
            lossTot = sum_loss/TRAIN_ROUNDS
            lossPol = sum_policy_loss/TRAIN_ROUNDS
            lossVal = sum_value_loss/TRAIN_ROUNDS
            print("loss_total: %1.2f, loss_value: %1.2f, loss_policy: %1.2f" % (lossTot, lossVal, lossPol))
            writer.add_scalar("loss_total", lossTot, step_idx)
            writer.add_scalar("loss_value", lossVal, step_idx)
            writer.add_scalar("loss_policy", lossPol, step_idx)
        timeTrain.stop()
        # evaluate net
        if step_idx % EVALUATE_EVERY_STEP == 0:
            win_ratio = evaluate(net, bestNet, rounds=EVALUATION_ROUNDS, device=device)
            print("Net evaluated, win ratio = %.2f" % win_ratio)
            writer.add_scalar("eval_win_ratio", win_ratio, step_idx)
            if win_ratio > BEST_NET_WIN_RATIO:
                print("Net is better than cur best, sync")
                bestNet.load_state_dict(net.state_dict())
                best_idx += 1
                model = 'RL' + str(best_idx)
                torch.save(net, dirSave + '/go_' + model + '.pt')
                with open(dirSave + '/go_' + model + '.txt', "w") as outfile:
                    outfile.write("\n".join(["_".join([str(a[0]), str(a[1]), str(a[2]), str(a[3])])
                                             for a in replay_buffer]))
                ###showWeights(model)
                test(model, printNurSummary=True)
                mcts_store.clear()
    timeTrain.timerPrint()
    timeGame.timerPrint()
    return best_idx
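The training rounds above combine a value MSE against the final game result with a policy cross-entropy against the MCTS visit distribution. Below is a self-contained sketch of that loss; combined_loss and the tensor names are illustrative.

# Illustrative sketch of the combined AlphaZero-style loss used in trainMCTS.
import torch.nn.functional as F

def combined_loss(out_logits, out_values, target_probs, target_values):
    value_loss = F.mse_loss(out_values.squeeze(-1), target_values)
    policy_loss = -(target_probs * F.log_softmax(out_logits, dim=1)).sum(dim=1).mean()
    return policy_loss + value_loss, value_loss, policy_loss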
Example #5
def trainSL(net):
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9)
    os.chdir('/Users/kai/Documents/Python/ML/go9x9')
    dateien = os.listdir(".")
    buffer = collections.deque()    # entries: board2 int, whoMoves, probs, value
    # supervised teaching from a set of games
    for datei in dateien:
        if datei.endswith(".sgf"):
            print('processing ' + datei)
            with open(datei) as f:
                collection = sgf.parse(f.read())
            try:
                gewinner = collection[0].root.properties['RE'][0][0]
            except (KeyError, IndexError):
                print('No winner recorded yet')
                continue
            for drehung in (0, 90, 180, 270, 1, 2, 3, 4):
                spiel = goSpielNoGraph.PlayGo(gameGo.b2Initial)
                passM1 = 0    # counts consecutive pass moves
                for node in collection[0].rest:
                    for k, v in node.properties.items():
                        if k == 'B':
                            player = 1
                        elif k == 'W':
                            player = 0
                        else:
                            print('Wrong key, neither B nor W!')
                            sys.exit()
                        if len(v[0]) == 0 or v[0] == 'pj' or v[0] == 'mj':
                            if not node.next and not passM1:
                                position = 82
                                spiel.setzZug(position)
                            else:
                                position = 81
                                # pass moves are not trained
                                passM1 += 1
                                if passM1 <= 2:
                                    spiel.setzZug(position)
                        elif len(v[0]) == 1:
                            print('Invalid move syntax: ' + v[0])
                            sys.exit()
                        else:
                            passM1 = 0
                            reiheStr = v[0][1]
                            if reiheStr == 'i':
                                reiheStr = 'j'
                            spalteStr = v[0][0]
                            if spalteStr == 'i':
                                spalteStr = 'j'
                            if spalteStr not in col or reiheStr not in col:
                                print('Invalid move syntax: ' + v[0])
                                sys.exit()
                            spalte = col.index(spalteStr)
                            reihe = col.index(reiheStr)
                            reihe, spalte = gameGo.dreh(reihe, spalte, drehung)
                            probs = [0] * gameGo.ANZ_POSITIONS
                            probs[gameGo.size*reihe+spalte] = 1
                            if k == gewinner:
                                buffer.append((spiel.bToInt(), player, probs, 1))
                            else:
                                buffer.append((spiel.bToInt(), player, probs, -1))
                            position = gameGo.size * reihe + spalte
                            if not spiel.setzZug(position):
                                print('Illegal move: row='+str(reihe)+' column='+str(spalte))
                                sys.exit()
                if spiel.gewinner == 1 and not gewinner == 'B' \
                        or spiel.gewinner == -1 and not gewinner == 'W':
                    print('At rotation ', drehung, ' the winner in the SGF is inconsistent with the game!')
                    print('spiel.gewinner = '+str(spiel.gewinner)+' SGF winner = '+gewinner)
                    print('Points B: ', spiel.pktSchwarz, ' Points W: ', spiel.pktWeiss)
                    shutil.move(datei, '/Users/kai/Documents/Python/ML/go9x9test/'+datei)
                    break
    # train
    # with m*m instances, use a mini-batch size of m/2
    batch_size = int(math.sqrt(len(buffer))/2)
    print('Instances: ', len(buffer), ', Mini-Batch Size: ', batch_size)
    for _ in range(len(buffer)//batch_size):
        batch = random.sample(buffer, batch_size)
        batch_states, batch_who_moves, batch_probs, batch_values = zip(*batch)
        batch_states_lists = [gameGo.intToB(state) for state in batch_states]
        states_v = gameGo.state_lists_to_batch(batch_states_lists, batch_who_moves)
        optimizer.zero_grad()
        probs_v = torch.FloatTensor(batch_probs)
        values_v = torch.FloatTensor(batch_values)
        out_logits_v, out_values_v = net(states_v)

        loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
        loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
        loss_policy_v = loss_policy_v.sum(dim=1).mean()

        loss_v = loss_policy_v + GEWICHT_SL_MSE*loss_value_v
        loss_v.backward()
        optimizer.step()
        print('loss_total: ' + str(loss_v.item()) + ', loss_value: ' + str(loss_value_v.item())
              + ', loss_policy: ' + str(loss_policy_v.item()))
    torch.save(net, dirSave+'/go_SL.pt')
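The SGF moves parsed above are a column letter followed by a row letter, with 'i' mapped to 'j' because the board's column string skips 'i'. Below is a small stand-alone sketch of that conversion; COLS is an assumption standing in for the project's module-level col string.

# Illustrative sketch; COLS is an assumed stand-in for the project's `col` string.
COLS = 'abcdefghj'    # 'i' is skipped on the 9x9 board

def sgf_to_index(move, size=9):
    col_chr, row_chr = move[0], move[1]
    col_chr = 'j' if col_chr == 'i' else col_chr
    row_chr = 'j' if row_chr == 'i' else row_chr
    return size * COLS.index(row_chr) + COLS.index(col_chr)

# e.g. sgf_to_index('cc') -> 20  (row 2, column 2)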