def eval(agent: Agent, env: Checkers, color: str, n_games=100):
    """Evaluate `agent` against a uniformly random opponent over `n_games` games.

    `agent` plays the side named by `color` ('black' or 'white'); a freshly
    constructed Agent with epsilon=1 (pure random policy, no learning) plays
    the other side.

    Side effects: `agent.net` is put in eval mode for the duration and
    restored to train mode before returning; the environment is reset to its
    state at call time before every game.

    Returns the win rate (float in [0, 1]) of `agent`'s color.
    """
    agent.net.eval()
    # Random-policy opponent: epsilon=1 means always explore; lr=0,
    # eps_dec=0 and max_mem_size=0 disable any learning/replay.
    opponent = Agent(gamma=agent.gamma, epsilon=1, lr=0,
                     input_dims=[8 * 8 + 1],
                     batch_size=agent.batch_size,
                     action_space=agent.action_space,
                     eps_dec=0, max_mem_size=0)
    opponent.net.eval()
    initial_state = env.save_state()
    score = {'black': 0, 'white': 0}
    for _ in tqdm(range(n_games)):
        env.restore_state(initial_state)
        winner = None
        moves = torch.tensor(env.legal_moves())
        # save_state() yields (board, turn, last_moved_piece); only the
        # side to move is needed here.
        _, turn, _ = env.save_state()
        brain = agent if turn == color else opponent
        # Observation = flattened board + a 1-element turn flag (1.0 = black).
        board_tensor = torch.from_numpy(env.flat_board()).view(-1).float()
        encoded_turn = torch.tensor([1. if turn == 'black' else 0.])
        observation = torch.cat([board_tensor, encoded_turn])
        while not winner:
            action = brain.choose_action(observation)
            # Resample until the (possibly exploratory) policy produces a
            # legal move for the current position.
            while not action_is_legal(action, moves):
                action = brain.choose_action(observation)
            _, new_turn, _, moves, winner = env.move(*action.tolist())
            moves = torch.tensor(moves)
            # BUG FIX: advance `turn` to the side to move next.  The
            # original never updated `turn`, so the same brain played both
            # sides for the whole game and the turn flag in the
            # observation stayed stale.
            turn = new_turn
            board_tensor = torch.from_numpy(env.flat_board()).view(-1).float()
            encoded_turn = torch.tensor([1. if turn == 'black' else 0.])
            observation = torch.cat([board_tensor, encoded_turn])
            brain = agent if turn == color else opponent
        score[winner] += 1
    agent.net.train()
    return score[color] / n_games
lr=args.lr, eps_dec=args.epsilon_decay) } env = Checkers() initial_state = env.save_state() eps_history = [] score = {'black': 0, 'white': 0} os.makedirs(args.checkpoints_dir, exist_ok=True) for i in range(args.games): print( f"episode={i}, score={score}, black_eps:{players['black'].epsilon}, white_eps:{players['white'].epsilon}" ) score = {'black': 0, 'white': 0} env.restore_state(initial_state) winner = None moves = torch.tensor(env.legal_moves()) board, turn, last_moved_piece = env.save_state() brain = players[turn] board_tensor = torch.from_numpy(env.flat_board()).view(-1).float() encoded_turn = torch.tensor([1.]) if turn == 'black' else torch.tensor( [0.]) observation = torch.cat([board_tensor, encoded_turn]) while not winner: action = brain.choose_action(observation) if not action_is_legal(action, moves): reward = -1000000 new_turn = turn else: new_board, new_turn, _, moves, winner = env.move(