def search_minibatch(self, count, s, player, net, zugNr, zugMax, device):
    """
    Perform several MCTS searches, expanding new leaves in one network batch.
    """
    # return: number of find_leaf calls that reached the end of the game
    countEnd = 0
    backup_queue = []
    expand_states = []
    expand_players = []
    expand_queue = []
    for _ in range(count):
        value, leaf_state, leaf_player, states, actions = self.find_leaf(s, player, zugNr, zugMax)
        if value is not None:
            countEnd += 1
            backup_queue.append((value, states, actions))
        else:
            # queue the leaf for batched expansion, skipping leaves that are already queued
            if leaf_state not in (item[0] for item in expand_queue):
                expand_states.append(gameGo.intToB(leaf_state))
                expand_players.append(leaf_player)
                expand_queue.append((leaf_state, states, actions))
    # expansion of nodes
    if expand_queue:
        batch_v = gameGo.state_lists_to_batch(expand_states, expand_players, device)
        logits_v, values_v = net(batch_v)
        probs_v = F.softmax(logits_v, dim=1)
        values = values_v.data.cpu().numpy()[:, 0]
        probs = probs_v.detach().cpu().numpy()
        # create the nodes
        for (leaf_state, states, actions), value, prob in zip(expand_queue, values, probs):
            self.stateStats.expand(leaf_state, prob)
            backup_queue.append((value, states, actions))
    # backup with value, states, actions
    for value, states, actions in backup_queue:
        # the leaf state is not stored in states + actions, so the value of the leaf
        # is the value from the opponent's point of view
        cur_value = -value
        for state, action in zip(states[::-1], actions[::-1]):
            self.stateStats.backup(state, action, cur_value)
            cur_value = -cur_value
    return countEnd
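# Standalone sketch (not part of the original module): the sign-flipping backup rule
# used above. Because the visited path alternates between the two players, the leaf
# value is negated once per ply on its way back to the root; the leaf itself is not
# stored in states/actions, so the first credited value is the opponent's view (-value).
def backup_values(path_len, leaf_value):
    # returns the value credited to each (state, action) pair, deepest move first
    values = []
    cur_value = -leaf_value
    for _ in range(path_len):
        values.append(cur_value)
        cur_value = -cur_value
    return values

# backup_values(3, 1.0) -> [-1.0, 1.0, -1.0]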
def search(self, count, batch_size, s, player, net, zugNr, zugMax, device):
    # return: number of find_leaf calls that reached the end of the game
    countEnd = 0
    if batch_size > 0:
        for _ in range(count):
            countEndMini = self.search_minibatch(batch_size, s, player, net, zugNr, zugMax, device)
            countEnd += countEndMini
    else:
        for _ in range(count):
            value, leaf_state, leaf_player, states, actions = self.find_leaf(s, player, zugNr, zugMax)
            if value is None:
                # expand with leaf_state, leaf_player, states, actions
                batch_v = gameGo.state_lists_to_batch([gameGo.intToB(leaf_state)], [leaf_player], device)
                logits_v, value_v = net(batch_v)
                probs_v = F.softmax(logits_v, dim=1)
                probs = probs_v.detach().cpu().numpy()[0]
                value = value_v.data.cpu().numpy()[0][0]
                # create the node
                self.stateStats.expand(leaf_state, probs)
            else:
                countEnd += 1
                print('Leaf reached the end of the game.')
                cv = -value
                cp = leaf_player
                for state, action in zip(states[::-1], actions[::-1]):
                    print('backup with action: ', action, 'player: ', cp, ' value: ', cv, ' at:')
                    cv = -cv
                    cp = 1 - cp
                    gameGo.printBrett(gameGo.intToB(state)[0])
            # backup with value, states, actions
            # the leaf state is not stored in states + actions, so the value of the leaf
            # is the value from the opponent's point of view
            cur_value = -value
            for state, action in zip(states[::-1], actions[::-1]):
                self.stateStats.backup(state, action, cur_value)
                cur_value = -cur_value
    return countEnd
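# Illustrative usage sketch only (assumes mctsGo.MCTS and ZUG_MAX as defined elsewhere
# in this file, `net` a trained NnGo instance and `spiel` a PlayGo instance); it
# contrasts the sequential search with the mini-batched variant.
def _search_usage_example(spiel, net, device='cpu'):
    mcts = mctsGo.MCTS()
    state = spiel.bToInt()
    # batch_size == 0: 32 sequential find_leaf/expand/backup cycles, one net call per new leaf
    mcts.search(32, 0, state, player=1, net=net, zugNr=1, zugMax=ZUG_MAX, device=device)
    # batch_size == 8: 4 rounds, each expanding up to 8 distinct leaves with a single net call
    mcts.search(4, 8, state, player=1, net=net, zugNr=1, zugMax=ZUG_MAX, device=device)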
def play_game(mcts_stores, replay_buffer, net1, net2, steps_before_tau_0, mcts_searches,
              mcts_batch_size=0, stat='nicht', device='cpu'):
    """
    Play one single game, memorizing transitions into the replay buffer
    :param mcts_stores: could be a single MCTS or two MCTSes, one for each net
    :param replay_buffer: queue with (state, player, probs, result), if None, nothing is stored
    :param net1: player1
    :param net2: player2
    :param mcts_batch_size: batch size for the MCTS mini-batch search, 0: no mini-batch call
    :return: value for the game in respect to net1 (+1 if p1 won, -1 if lost, 0 if draw)
    Statistics: the share of leaf calls is measured during the first evaluation game,
                the differences between MCTS and NN are measured during the last evaluation game;
                both are controlled via PLAY_STATISTIK: off, summary only, detailed
    """
    # assert isinstance(replay_buffer, (collections.deque, type(None)))
    # assert isinstance(mcts_stores, (mctsGo.MCTS, type(None), list))
    # assert isinstance(net1, NnGo)
    # assert isinstance(net2, NnGo)
    if isinstance(mcts_stores, mctsGo.MCTS):
        mcts_stores = [mcts_stores, mcts_stores]
    spiel = goSpielNoGraph.PlayGo(gameGo.b2Initial, zugMax=ZUG_MAX)
    state = spiel.bToInt()
    nets = [net1, net2]
    cur_player = 1  # black always starts, and that is net1
    step = 0
    countDiff = 0
    countEnd = 0
    countSearch = mcts_searches * mcts_batch_size if mcts_batch_size > 0 else mcts_searches
    tau = 1 if steps_before_tau_0 > 0 else 0
    game_history = []
    values, zuege = [], []
    while True:
        statEnd = mcts_stores[1-cur_player].search(mcts_searches, mcts_batch_size, state, cur_player,
                                                   nets[1-cur_player], zugNr=step+1, zugMax=ZUG_MAX,
                                                   device=device)
        countEnd += statEnd
        probs = mcts_stores[1-cur_player].get_policy(state, tau=tau)
        game_history.append((state, cur_player, probs))
        action = np.random.choice(gameGo.ANZ_POSITIONS, p=probs)
        if not spiel.setzZug(action):   # play the chosen move
            print('Impossible action at step ', step, ', Player: ', cur_player, '. Action=', action, ' at:')
            spiel.printB()
            print('b1:')
            spiel.printB(b1=True)
            print('with probs:')
            gameGo.printBrett(probs, istFlat=True, mitFloat=True)
            counts = mcts_stores[1-cur_player].stateStats.b[state][0]
            print('Counts:')
            gameGo.printBrett(counts, istFlat=True)
            counts[action] = 0
            if not spiel.setzZug(np.argmax(counts)):
                spiel.setzZug(81)
        elif PLAY_STATISTIK == 1:
            zuege.append(action)
            values.append('%1.2f ' % (mcts_stores[1-cur_player].stateStats.b[state][2][action]))
        if PLAY_STATISTIK > 0 and stat != 'nicht':
            batch_v = gameGo.state_lists_to_batch([gameGo.intToB(state)], [cur_player], device)
            p_v, _ = nets[1-cur_player](batch_v)
            probs = p_v.detach().cpu().numpy()[0]
            position = np.argmax(probs)
            if position != action:
                countDiff += 1
                if PLAY_STATISTIK == 2:
                    print('play_game step ', step+1, ' action differs!')
                    print('Action MCTS: ', action, ' NN: ', position)
                    print('Share of leaf calls reaching game end: ' + str(statEnd) + ' = '
                          + str(int(statEnd*100/countSearch)) + '%')
                    print('')
        if spiel.spielBeendet:
            # print('Winner:', spiel.gewinner, 'B:', spiel.pktSchwarz, 'W:', spiel.pktWeiss)
            if PLAY_STATISTIK == 1:
                spiel.sgfWrite(zuege, values)
            if spiel.gewinner == 1:
                net1_result = 1
                result = 1 if cur_player == 1 else -1
            elif spiel.gewinner == -1:
                net1_result = -1
                result = -1 if cur_player == 1 else 1
            else:
                result = 0
                net1_result = 0
            break
        cur_player = 1-cur_player
        state = spiel.bToInt()
        step += 1
        if step >= steps_before_tau_0:
            tau = 0
    if PLAY_STATISTIK > 0:
        if stat == 'Diff':
            print('play_game differences MCTS vs NN: ' + str(countDiff) + ' = '
                  + str(int(countDiff*100/(step+1))) + '%')
        elif stat == 'Leaf':
            print('Total share of leaf calls reaching game end: ' + str(countEnd) + ' = '
                  + str(int(countEnd*100/(countSearch*(step+1)))) + '%')
    if replay_buffer is not None:
        for state, cur_player, probs in reversed(game_history):
            for drehung in (0, 90, 180, 270, 1, 2, 3, 4):
                replay_buffer.append((gameGo.drehB2(state, drehung), cur_player,
                                      gameGo.drehPosition(probs, drehung), result))
            result = -result
    return net1_result, step
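# Standalone sketch (not part of the original module): how the replay buffer above is
# labeled, omitting the eight board symmetries. Walking the game history backwards,
# each position stores the final result from the point of view of the player to move,
# and the sign flips once per ply.
def label_history(game_history, final_result):
    labeled = []
    result = final_result
    for state, cur_player, probs in reversed(game_history):
        labeled.append((state, cur_player, probs, result))
        result = -result
    return labeled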
def trainMCTS(net, aufsatz, device):
    # input: net, number of the RL run to resume from (0: start from scratch)
    # return: number of the best net
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9)
    replay_buffer = collections.deque(maxlen=REPLAY_BUFFER)
    if aufsatz > 0:
        model = 'RL' + str(aufsatz)
        with open(dirSave + '/go_' + model + '.txt', "r") as infile:
            for line in infile:
                items = line.split('_')
                i2 = items[2][1:-1]
                i2 = i2.split(',')
                i2List = []
                for i in range(gameGo.ANZ_POSITIONS):
                    i2List.append(float(i2[i]))
                replay_buffer.append((int(items[0]), int(items[1]), i2List, int(items[3])))
    mcts_store = mctsGo.MCTS()
    best_idx = aufsatz
    bestNet = copy.deepcopy(net)   # the best net starts as a copy of the current net
    timeTrain = gameGo.GoTimer('trainMCTS', mitGesamt=True)
    timeGame = gameGo.GoTimer('play_game')
    prev_nodes = 0
    for step_idx in range(1, MAX_STEP+1):
        game_steps = 0
        timeTrain.start()
        timeGame.start()
        for _ in range(PLAY_EPISODES):
            _, steps = play_game(mcts_store, replay_buffer, bestNet, bestNet,
                                 steps_before_tau_0=STEPS_BEFORE_TAU_0, mcts_searches=MCTS_SEARCHES,
                                 mcts_batch_size=MCTS_BATCH_SIZE, device=device)
            game_steps += steps
        timeGame.stop()
        if step_idx % PRINT_EVERY_STEP == 0:
            game_nodes = len(mcts_store) - prev_nodes
            prev_nodes = len(mcts_store)
            print("Step %d, Moves last step %3d, New leaves %3d, Best net %d, Replay size %d" % (
                step_idx, game_steps, game_nodes, best_idx, len(replay_buffer)))
        if len(replay_buffer) < MIN_REPLAY_TO_TRAIN:
            continue
        # train
        sum_loss = 0.0
        sum_value_loss = 0.0
        sum_policy_loss = 0.0
        for _ in range(TRAIN_ROUNDS):
            batch = random.sample(replay_buffer, BATCH_SIZE)
            batch_states, batch_who_moves, batch_probs, batch_values = zip(*batch)
            batch_states_lists = [gameGo.intToB(state) for state in batch_states]
            states_v = gameGo.state_lists_to_batch(batch_states_lists, batch_who_moves, device)
            optimizer.zero_grad()
            probs_v = torch.FloatTensor(batch_probs).to(device)
            values_v = torch.FloatTensor(batch_values).to(device)
            out_logits_v, out_values_v = net(states_v)
            loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
            loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
            loss_policy_v = loss_policy_v.sum(dim=1).mean()
            loss_v = loss_policy_v + loss_value_v
            loss_v.backward()
            optimizer.step()
            sum_loss += loss_v.item()
            sum_value_loss += loss_value_v.item()
            sum_policy_loss += loss_policy_v.item()
        if step_idx % PRINT_EVERY_STEP == 0:
            lossTot = sum_loss/TRAIN_ROUNDS
            lossPol = sum_policy_loss/TRAIN_ROUNDS
            lossVal = sum_value_loss/TRAIN_ROUNDS
            print("loss_total: %1.2f, loss_value: %1.2f, loss_policy: %1.2f" % (lossTot, lossVal, lossPol))
            writer.add_scalar("loss_total", lossTot, step_idx)
            writer.add_scalar("loss_value", lossVal, step_idx)
            writer.add_scalar("loss_policy", lossPol, step_idx)
        timeTrain.stop()
        # evaluate net
        if step_idx % EVALUATE_EVERY_STEP == 0:
            win_ratio = evaluate(net, bestNet, rounds=EVALUATION_ROUNDS, device=device)
            print("Net evaluated, win ratio = %.2f" % win_ratio)
            writer.add_scalar("eval_win_ratio", win_ratio, step_idx)
            if win_ratio > BEST_NET_WIN_RATIO:
                print("Net is better than cur best, sync")
                bestNet.load_state_dict(net.state_dict())
                best_idx += 1
                model = 'RL' + str(best_idx)
                torch.save(net, dirSave + '/go_' + model + '.pt')
                with open(dirSave + '/go_' + model + '.txt', "w") as outfile:
                    outfile.write("\n".join(["_".join([str(a[0]), str(a[1]), str(a[2]), str(a[3])])
                                             for a in replay_buffer]))
                ###showWeights(model)
                test(model, printNurSummary=True)
                mcts_store.clear()
    timeTrain.timerPrint()
    timeGame.timerPrint()
    return best_idx
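# Minimal standalone sketch of the combined loss used in the training loop above
# (torch and F are the module's existing imports): value MSE plus policy cross-entropy
# between the MCTS visit distribution and the network logits.
def alpha_zero_loss(out_logits, out_values, target_probs, target_values):
    loss_value = F.mse_loss(out_values.squeeze(-1), target_values)
    loss_policy = -(F.log_softmax(out_logits, dim=1) * target_probs).sum(dim=1).mean()
    return loss_policy + loss_value

# Example on dummy tensors for a hypothetical batch of 4 positions and 82 move slots:
#   logits = torch.randn(4, 82); values = torch.randn(4, 1)
#   probs = torch.softmax(torch.randn(4, 82), dim=1)
#   z = torch.tensor([1., -1., 1., 0.])
#   alpha_zero_loss(logits, values, probs, z)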
def trainSL(net):
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9)
    os.chdir('/Users/kai/Documents/Python/ML/go9x9')
    dateien = os.listdir(".")
    buffer = collections.deque()   # holds (board2 int, whoMoves, probs, value)
    # teach from a number of recorded games
    for datei in dateien:
        if datei.endswith(".sgf"):
            print('processing ' + datei)
            with open(datei) as f:
                collection = sgf.parse(f.read())
            try:
                gewinner = collection[0].root.properties['RE'][0][0]
            except:
                print('No winner yet')
                continue
            for drehung in (0, 90, 180, 270, 1, 2, 3, 4):
                spiel = goSpielNoGraph.PlayGo(gameGo.b2Initial)
                passM1 = False
                for node in collection[0].rest:
                    for k, v in node.properties.items():
                        if k == 'B':
                            player = 1
                        elif k == 'W':
                            player = 0
                        else:
                            print('Wrong key, neither B nor W!')
                            sys.exit()
                        if len(v[0]) == 0 or v[0] == 'pj' or v[0] == 'mj':
                            if not node.next and not passM1:
                                position = 82
                                spiel.setzZug(position)
                            else:
                                position = 81
                                # passes are not used for training
                                passM1 += 1
                                if passM1 <= 2:
                                    spiel.setzZug(position)
                        elif len(v[0]) == 1:
                            print('Wrong move syntax: ' + v[0])
                            sys.exit()
                        else:
                            passM1 = 0
                            reiheStr = v[0][1]
                            if reiheStr == 'i':
                                reiheStr = 'j'
                            spalteStr = v[0][0]
                            if spalteStr == 'i':
                                spalteStr = 'j'
                            if spalteStr not in col or reiheStr not in col:
                                print('Wrong move syntax: ' + v[0])
                                sys.exit()
                            spalte = col.index(spalteStr)
                            reihe = col.index(reiheStr)
                            reihe, spalte = gameGo.dreh(reihe, spalte, drehung)
                            probs = [0] * gameGo.ANZ_POSITIONS
                            probs[gameGo.size*reihe + spalte] = 1
                            if k == gewinner:
                                buffer.append((spiel.bToInt(), player, probs, 1))
                            else:
                                buffer.append((spiel.bToInt(), player, probs, -1))
                            position = gameGo.size*reihe + spalte
                            if not spiel.setzZug(position):
                                print('Illegal move: row=' + str(reihe) + ' column=' + str(spalte))
                                sys.exit()
                if (spiel.gewinner == 1 and gewinner != 'B') or (spiel.gewinner == -1 and gewinner != 'W'):
                    print('At rotation: ', drehung, ' the winner in the SGF is inconsistent with the game!')
                    print('spiel.gewinner = ' + str(spiel.gewinner) + ' SGF gewinner = ' + gewinner)
                    print('PktB: ', spiel.pktSchwarz, ' PktW: ', spiel.pktWeiss)
                    shutil.move(datei, '/Users/kai/Documents/Python/ML/go9x9test/' + datei)
                    break
    # train
    # with m*m instances, use a mini-batch size of m/2
    batch_size = int(math.sqrt(len(buffer))/2)
    print('Instances: ', len(buffer), ', Mini-Batch Size: ', batch_size)
    for _ in range(len(buffer)//batch_size):
        batch = random.sample(buffer, batch_size)
        batch_states, batch_who_moves, batch_probs, batch_values = zip(*batch)
        batch_states_lists = [gameGo.intToB(state) for state in batch_states]
        states_v = gameGo.state_lists_to_batch(batch_states_lists, batch_who_moves)
        optimizer.zero_grad()
        probs_v = torch.FloatTensor(batch_probs)
        values_v = torch.FloatTensor(batch_values)
        out_logits_v, out_values_v = net(states_v)
        loss_value_v = F.mse_loss(out_values_v.squeeze(-1), values_v)
        loss_policy_v = -F.log_softmax(out_logits_v, dim=1) * probs_v
        loss_policy_v = loss_policy_v.sum(dim=1).mean()
        loss_v = loss_policy_v + GEWICHT_SL_MSE*loss_value_v
        loss_v.backward()
        optimizer.step()
        print('loss_total: ' + str(loss_v.item()) + ', loss_value: ' + str(loss_value_v.item())
              + ', loss_policy: ' + str(loss_policy_v.item()))
    torch.save(net, dirSave + '/go_SL.pt')
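# Standalone sketch of the mini-batch sizing rule used above (math is already imported
# for trainSL): with roughly m*m buffered instances the batch size is m/2, so about
# 2*m sampled batches cover the buffer once.
def minibatch_plan(num_instances):
    batch_size = int(math.sqrt(num_instances) / 2)
    num_batches = num_instances // batch_size
    return batch_size, num_batches

# minibatch_plan(10000) -> (50, 200)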