def test_insert_some_coins():
    """Play one coin in column 3, then stack coins in column 2, checking each step.

    After every insertion the test checks three things:
    - whose turn it is;
    - that the board equals a ``Board`` built from the expected per-column
      integer values;
    - the exact tuple returned by ``valid_actions()``.

    After the sixth coin lands in column 2, column 2 is no longer listed as
    a valid action.
    """
    every_column = (0, 1, 2, 3, 4, 5, 6)

    board = Board()
    assert board.turn() == 'O'

    # Each entry: (column played, player to move next,
    #              expected column values, expected valid actions)
    expected_steps = [
        (3, 'X', [0, 0, 0, 0b01, 0, 0, 0], every_column),
        (2, 'O', [0, 0, 0b10, 0b01, 0, 0, 0], every_column),
        (2, 'X', [0, 0, 0b0110, 0b01, 0, 0, 0], every_column),
        (2, 'O', [0, 0, 0b100110, 0b01, 0, 0, 0], every_column),
        (2, 'X', [0, 0, 0b01100110, 0b01, 0, 0, 0], every_column),
        (2, 'O', [0, 0, 0b1001100110, 0b01, 0, 0, 0], every_column),
        (2, 'X', [0, 0, 0b011001100110, 0b01, 0, 0, 0], (0, 1, 3, 4, 5, 6)),
    ]

    for column, next_player, column_values, playable in expected_steps:
        board = board.insert(column)
        assert board.turn() == next_player
        assert board == Board(column_values)
        assert board.valid_actions() == playable
def select_action(policy, board: Board, cuda=False, noise=0):
    """Sample an action from ``policy``, restricted to the board's legal moves.

    Steps:
    1. Flatten the board to a ``(1, BOARD_ROWS * BOARD_COLS)`` float tensor
       and pass it through ``policy``.
    2. Build a 0/1 mask that keeps only the actions in
       ``board.valid_actions()``.
    3. If inserting into a legal column makes the player to move
       (``board.turn()``) the winner, reduce the mask to that single column.
    4. Add ``noise`` to every entry that survives the mask.
    5. If the masked weights sum to (essentially) zero, give every allowed
       action an extra weight of 1, so sampling falls back to random legal
       moves.

    Args:
        policy: Callable applied to the state tensor that returns per-action
            weights. Presumably shaped ``(1, len(POSSIBLE_ACTIONS))``; TODO
            confirm against the model definition.
        board: Position to choose a move from.
        cuda: If true, move the state, mask and noise tensors to the GPU.
        noise: Scalar added to the weight of each unmasked action.

    Returns:
        ``probs.multinomial(1)``: one sampled action index per row of
        ``probs``.
    """
    # Get probabilities from neural network
    state = torch.from_numpy(
        board.matrix().reshape(BOARD_ROWS * BOARD_COLS)
    ).float().unsqueeze(0)
    if cuda:
        state = state.cuda()
    probs = policy(Variable(state))

    # Exclude any results that are not allowed
    mult_np = np.zeros(len(POSSIBLE_ACTIONS), dtype=np.float32)
    allowed_actions = board.valid_actions()
    for i in POSSIBLE_ACTIONS:
        if i in allowed_actions:
            mult_np[i] = 1
    # NOTE(review): if allowed_actions is empty (full board) the mask is all
    # zeros and multinomial() below will reject the all-zero weights.

    # Always choose winning move. If several moves win, the last one checked
    # determines the mask.
    for a in allowed_actions:
        hypothetical_board = board.insert(a)
        if hypothetical_board.winner() == board.turn():
            mult_np = np.zeros(len(POSSIBLE_ACTIONS), dtype=np.float32)
            mult_np[a] = 1

    mult = Variable(torch.from_numpy(mult_np))
    # Use a separate name so the scalar ``noise`` argument is not shadowed
    # by the noise tensor.
    noise_var = Variable(torch.from_numpy(mult_np * noise))
    if cuda:
        mult = mult.cuda()
        noise_var = noise_var.cuda()
    probs = probs * mult + noise_var

    # `.data.sum()` is a plain number on old PyTorch and a 0-d tensor on
    # newer releases; both compare correctly here. The previous
    # `torch.sum(...).data[0]` indexes a 0-d tensor, which newer PyTorch
    # rejects.
    if (probs * mult).data.sum() < 1e-40:
        # Neural network only offered things that are not allowed, so we go for random
        probs = probs + mult
    # Pass num_samples explicitly. Old PyTorch defaulted it to 1; current
    # releases require the argument.
    return probs.multinomial(1)