def main(folderb, folderw, black, white):
    player_black, checkpoint_black = load_player(str(folderb), black)
    player_white, checkpoint_white = load_player(str(folderw), white)
    print("Loaded Black player with iteration " + str(checkpoint_black['total_ite']))
    print("Loaded White player with iteration " + str(checkpoint_white['total_ite']))

    ## Start method for PyTorch multiprocessing
    multiprocessing.set_start_method('spawn')
    evaluate(player_black, player_white)
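## A minimal CLI wrapper sketch for the evaluation entry point above, assuming
## it lives in its own script. The flag names simply mirror main()'s parameters
## and are illustrative, not the project's actual command-line interface; the
## default of None for the iterations assumes load_player treats it as "latest".
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate a Black player against a White player")
    parser.add_argument("--folderb", required=True, help="checkpoint folder of the Black player")
    parser.add_argument("--folderw", required=True, help="checkpoint folder of the White player")
    parser.add_argument("--black", type=int, default=None, help="iteration of the Black player to load")
    parser.add_argument("--white", type=int, default=None, help="iteration of the White player to load")
    args = parser.parse_args()

    main(args.folderb, args.folderw, args.black, args.white)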
def main(folder, ite, gtp):
    player, _ = load_player(folder, ite)

    ## load_player returns an error string when no model could be loaded
    if gtp and not isinstance(player, str):
        game = Game(player, 0)
        engine = Engine(game, board_size=game.goban_size)
        while True:
            print(engine.send(input()))
    elif not gtp:
        print(player)
    else:
        print("¯\\_(ツ)_/¯")
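## A slightly more defensive sketch of the GTP read-eval-print loop above,
## under the same assumptions (an Engine whose send(command) returns a GTP
## response string). Unlike the bare `while True: input()`, it exits cleanly
## when stdin closes or the controller sends "quit", instead of raising EOFError.
import sys

def gtp_loop(engine):
    for line in sys.stdin:
        command = line.strip()
        if not command:
            continue                     # GTP allows blank lines between commands
        print(engine.send(command))
        sys.stdout.flush()               # controllers wait on flushed responses
        if command == "quit":
            break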
def train(current_time, loaded_version):
    """ Train the models using the data generated by the self-play """

    last_id = 0
    total_ite = 0
    lr = LR
    version = 1
    pool = False
    criterion = AlphaLoss()
    dataset = SelfPlayDataset()

    ## Database connection
    client = MongoClient()
    collection = client.superGo[current_time]

    ## First player, either from disk or fresh
    if loaded_version:
        player, checkpoint = load_player(current_time, loaded_version)
        optimizer = create_optimizer(player, lr, param=checkpoint['optimizer'])
        total_ite = checkpoint['total_ite']
        lr = checkpoint['lr']
        version = checkpoint['version']
        last_id = collection.find().count() - (MOVES // MOVE_LIMIT) * 2
    else:
        player = Player()
        optimizer = create_optimizer(player, lr)
        state = create_state(version, lr, total_ite, optimizer)
        player.save_models(state, current_time)
    best_player = deepcopy(player)

    ## Callback run after the evaluation is done, must be a closure
    def new_agent(result):
        if result:
            nonlocal version, pending_player, current_time, \
                     lr, total_ite, best_player
            version += 1
            state = create_state(version, lr, total_ite, optimizer)
            best_player = pending_player
            pending_player.save_models(state, current_time)
            print("[EVALUATION] New best player saved!")
        else:
            nonlocal last_id
            ## Force a new fetch in case the player didn't improve
            last_id = fetch_new_games(collection, dataset, last_id)

    ## Wait until the circular buffer is full before training
    while len(dataset) < MOVES:
        last_id = fetch_new_games(collection, dataset, last_id,
                                  loaded_version=loaded_version)
        time.sleep(30)
    print("[TRAIN] Circular buffer full!")
    print("[TRAIN] Starting to train!")
    dataloader = DataLoader(dataset, collate_fn=collate_fn,
                            batch_size=BATCH_SIZE, shuffle=True)

    while True:
        batch_loss = []
        for batch_idx, (state, move, winner) in enumerate(dataloader):
            running_loss = []
            lr, optimizer = update_lr(lr, optimizer, total_ite)

            ## Evaluate a copy of the current network asynchronously
            if total_ite % TRAIN_STEPS == 0:
                pending_player = deepcopy(player)
                last_id = fetch_new_games(collection, dataset, last_id)

                ## Wait in case an evaluation is still going on
                if pool:
                    print("[EVALUATION] Waiting for eval to end before re-eval")
                    pool.close()
                    pool.join()
                pool = MyPool(1)
                try:
                    pool.apply_async(evaluate, args=(pending_player, best_player),
                                     callback=new_agent)
                except Exception:
                    client.close()
                    pool.terminate()

            example = {'state': state, 'winner': winner, 'move': move}
            loss = train_epoch(player, optimizer, example, criterion)
            running_loss.append(loss)

            ## Print the running loss
            if total_ite % LOSS_TICK == 0:
                print("[TRAIN] current iteration: %d, averaged loss: %.3f"
                      % (total_ite, np.mean(running_loss)))
                batch_loss.append(np.mean(running_loss))
                running_loss = []

            ## Fetch new games
            if total_ite % REFRESH_TICK == 0:
                last_id = fetch_new_games(collection, dataset, last_id)

            total_ite += 1

        if len(batch_loss) > 0:
            print("[TRAIN] Average backward pass loss: %.3f, current lr: %f"
                  % (np.mean(batch_loss), lr))
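## AlphaLoss above comes from the project's own modules; as a reference, here is
## a minimal sketch of the AlphaGo Zero objective it presumably implements:
## l = (z - v)^2 - pi^T log(p), with the L2 regularisation term c * ||theta||^2
## left to the optimizer's weight decay. The names and shapes are assumptions:
## value and winner are (batch, 1); probs holds the MCTS visit-count targets pi
## and log_probs the log output of the policy head, both (batch, num_moves).
import torch.nn as nn

class AlphaLossSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, log_probs, probs, value, winner):
        ## Value head: mean squared error against the final game outcome z
        value_error = self.mse(value, winner)
        ## Policy head: cross-entropy between visit counts pi and predictions p
        policy_error = -(probs * log_probs).sum(dim=1).mean()
        return value_error + policy_error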