def ui_main(action_file=None, ai_first=False, depth=50, breadth=10):
    board = Board()
    ai, action = None, None
    ai_status = GameStatus.RedMoving if ai_first else GameStatus.BlackMoving
    # main loop
    while not board.won:
        print_board(board)
        if board.status == ai_status:
            ai = Node(board)
            action = ai.search(depth, breadth)
            write_action(ai, action, action_file)
        else:
            try:
                action = read_action(action_file)
            except (EOFError, KeyboardInterrupt):
                # end of input
                return
            except Exception:
                action = None
            if action is None:
                print('invalid command')
                continue
        board.apply_action(action)
    print_board(board)
    print('game over', end='')
    if board.status == GameStatus.RedWon:
        print(', red won', end='')
    elif board.status == GameStatus.BlackWon:
        print(', black won', end='')
    print()
def getPlayerMove(self):
    if self._board.is_game_over():
        print("Referee told me to play but the game is over!")
        return "PASS"
    start = time.perf_counter()
    if self.tree is None:
        self.tree = Node(None, None)
    while time.perf_counter() - start <= 5:
        leaf, actions = self.tree.select()
        for action in actions:
            self._board.push(action)
        if not self._board.is_game_over():
            leaf.expand(self._board.legal_moves())
            value = int(self.rollout())
        else:
            value = int(self._board.final_go_score()[0].lower() ==
                        Goban.Board.player_name(self._mycolor)[0])
        leaf.update(value)
        for action in actions:
            self._board.pop()
    # New here: allows considering internal representations of moves
    node, move, value, incertitude = self.tree.select_move(self._board.legal_moves())
    print("I am playing ", self._board.move_to_str(move),
          "with score:", value, "~", incertitude)
    print("My current board :")
    self._board.prettyPrint()
    self._board.push(move)
    # move is an internal representation; to communicate with the interface
    # I need to change it to a string
    return Goban.Board.flat_to_name(move)
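
# A minimal sketch of the tree Node interface the getPlayerMove loop above
# assumes (select / expand / update). The UCB1 formula and the constant c are
# my assumptions, not taken from the original source.
import math


class UctNode:
    def __init__(self, parent=None, action=None):
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0.0

    def select(self, c=1.4):
        """Descend by UCB1 to a leaf; return (leaf, actions along the path)."""
        node, actions = self, []
        while node.children:
            node = max(
                node.children,
                key=lambda n: n.value / (n.visits or 1)
                + c * math.sqrt(math.log(node.visits + 1) / (n.visits or 1)))
            actions.append(node.action)
        return node, actions

    def expand(self, legal_moves):
        """Attach one child per legal move."""
        self.children = [UctNode(parent=self, action=m) for m in legal_moves]

    def update(self, value):
        """Backpropagate a rollout result from this leaf up to the root."""
        node = self
        while node is not None:
            node.visits += 1
            node.value += value
            node = node.parent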
def test_two_node_tree(self):
    root = Node(None, None)
    child = root.add_child("move")
    self.assertTrue(root.is_root())
    self.assertFalse(root.is_leaf())
    self.assertFalse(child.is_root())
    self.assertTrue(child.is_leaf())
    self.assertIs(child.parent(), root)
    self.assertEqual(child.action(), "move")
def test_expand(expand):
    """Return True iff the expand method is implemented correctly."""
    # initialize a blank Gomoku board
    gomoku_init_state = GomokuState(use_default_heuristics=True, reward_player=0)
    gomoku_init_node = Node(gomoku_init_state)
    # black makes first move
    init_node = Node(gomoku_init_state)
    black_node = init_node.add_child(GomokuAction(0, (4, 4)))
    black_actions = list(black_node.unused_edges)
    num_edges = len(black_actions)
    num_samples = 500
    deviation = .20
    white_nodes = list([black_node.add_child(action) for action in black_actions])

    # count the results of calling `expand` many times
    frequency_dict = {}
    for i in range(num_edges * num_samples):
        init_node = Node(gomoku_init_state)
        black_node = init_node.add_child(GomokuAction(0, (4, 4)))
        white_node = expand(black_node)
        if white_node not in white_nodes:
            print(white_node)
            raise ValueError("returned a Node that is not associated "
                             "with an untried action!")
        frequency_dict[str(white_node)] = frequency_dict.get(str(white_node), 0) + 1

    # check that expand selects among untried actions uniformly at random
    for value in frequency_dict.values():
        if abs(value - num_samples) > num_samples * deviation:
            raise ValueError("possible actions are not being sampled "
                             "uniformly at random!")

    # check that an exception is raised for a fully expanded node
    init_node = Node(gomoku_init_state)
    black_node = init_node.add_child(GomokuAction(0, (4, 4)))
    white_move_1 = black_node.add_child(GomokuAction(1, (3, 4)))
    white_move_2 = black_node.add_child(GomokuAction(1, (3, 5)))
    white_move_3 = black_node.add_child(GomokuAction(1, (4, 5)))
    white_move_4 = black_node.add_child(GomokuAction(1, (5, 5)))
    white_move_5 = black_node.add_child(GomokuAction(1, (5, 4)))
    white_move_6 = black_node.add_child(GomokuAction(1, (5, 3)))
    white_move_7 = black_node.add_child(GomokuAction(1, (4, 3)))
    white_move_8 = black_node.add_child(GomokuAction(1, (3, 3)))
    try:
        expand(black_node)
    except Exception:
        pass
    else:
        raise Exception("Should throw an exception for trying to expand "
                        "a node that has already been expanded")
    return True
def one_run(env, n_turns, steepness, noise):
    env.max_turns = n_turns
    env.steepness = steepness
    env.noise_factor = noise
    trials = int(20 * 400 / n_turns)

    t = time.time()
    metrics_mcts_v3 = []
    for i in range(trials):
        env.reset()
        m = Metric('step', 'score')
        root = Node(0, 10)
        mcts = Mcts(root)
        done = False
        while not done:
            action = mcts.decide()
            _, r, done, _ = env.step(action)
            mcts.register(r)
        for j, r in enumerate(root.results):
            m.add_record(j, r)
        metrics_mcts_v3.append(m)
    metrics_mcts_v3 = sum(metrics_mcts_v3)
    print('Time for MCTSv3:', time.time() - t)

    t = time.time()
    import random
    metrics_rnd = []
    for i in range(trials):
        env.reset()
        m = Metric('step', 'score')
        rand_results = []
        done = False
        while not done:
            action = random.random() * 10
            _, r, done, _ = env.step(action)
            rand_results.append(r)
        for j, r in enumerate(rand_results):
            m.add_record(j, r)
        metrics_rnd.append(m)
    print('Time for RND:', time.time() - t)

    plot_group({'mcts_v3': metrics_mcts_v3, 'random': sum(metrics_rnd)},
               'temp', name=f'{n_turns}_st{steepness}_n{noise}')
def play_game(config: MuZeroConfig, network: Network) -> Game:
    game = config.new_game()
    while not game.terminal() and len(game.history) < config.max_moves:
        # create a new starting point for MCTS
        root = Node(0)
        current_observation = game.make_image(-1)
        root.expand_node(game.to_play(), game.legal_actions(),
                         network.initial_inference(current_observation))
        root.add_exploration_noise()

        # carry out the MCTS search
        run_mcts(config, root, game.action_history(), network)
        T = config.visit_softmax_temperature(num_moves=len(game.history),
                                             training_steps=network.training_steps())

        # first action from the MCTS with some extra exploration
        action, c1 = root.select_action_with_temperature(T, epsilon=config.epsilon)
        game.apply(action)
        game.store_search_statistics(root)

        # continue using actions as predicted by MCTS
        # (minimise exploration for these)
        ct = 1
        while not game.terminal() and ct < config.prediction_steps:
            action, c1 = c1.select_action_with_temperature(1)
            game.apply(action)
            game.store_search_statistics(c1)
            ct += 1
    return game
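
# Hedged sketch of temperature-based action selection as invoked above. The
# assumption (standard in MuZero-style agents) is sampling with probability
# proportional to visit_count ** (1 / T), acting greedily at T == 0; the
# original method also returns the chosen child node and takes an epsilon for
# extra exploration, which this sketch omits. `root.children` is assumed to
# map actions to nodes carrying a visit_count.
import numpy as np


def select_action_with_temperature(root, T):
    actions = list(root.children.keys())
    counts = np.array([root.children[a].visit_count for a in actions],
                      dtype=np.float64)
    if T == 0:
        idx = int(np.argmax(counts))  # act greedily
    else:
        probs = counts ** (1.0 / T)
        probs /= probs.sum()
        idx = int(np.random.choice(len(actions), p=probs))
    return actions[idx]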
def play_game(self, game):
    if self.config.fixed_temperatures is not None:
        self.temperature = self.config.visit_softmax_temperature(self.training_step)

    while not game.terminal:
        root = Node(0)
        current_observation = np.float32(game.get_observation(-1))
        if self.config.norm_obs:
            current_observation = (current_observation - self.obs_min) / self.obs_range
        current_observation = torch.from_numpy(current_observation).to(self.device)

        initial_inference = self.network.initial_inference(current_observation.unsqueeze(0))
        legal_actions = game.environment.legal_actions()
        root.expand(initial_inference, game.to_play, legal_actions)
        root.add_exploration_noise(self.config.root_dirichlet_alpha,
                                   self.config.root_exploration_fraction)

        self.mcts.run(root, self.network)

        error = root.value() - initial_inference.value.item()
        game.history.errors.append(error)

        action = self.config.select_action(root, self.temperature)
        game.apply(action)
        game.store_search_statistics(root)

        self.experiences_collected += 1
        if self.experiences_collected % self.config.weight_sync_frequency == 0:
            self.sync_weights()

        save_history = (game.history_idx - game.previous_collect_to) == self.config.max_history_length
        if save_history or game.done or game.terminal:
            overlap = self.config.num_unroll_steps + self.config.td_steps
            if not game.history.dones[game.previous_collect_to - 1]:
                collect_from = max(0, game.previous_collect_to - overlap)
            else:
                collect_from = game.previous_collect_to
            history = game.get_history_sequence(collect_from)
            ignore = overlap if not game.done else None
            self.replay_buffer.save_history.remote(history, ignore=ignore,
                                                   terminal=game.terminal)

        if game.step >= self.config.max_steps:
            self.environment.was_real_done = True
            break

    if self.config.two_players:
        self.stats_to_log[game.info["result"]] += 1
def store_search_statistics(self, root: Node):
    children_nodes = root.children.values()
    # total playthroughs extending from the root
    sum_visits = sum(child.visit_count for child in children_nodes)
    action_space = (Action(index) for index in range(self.action_space_size))
    self.child_visits.append([
        root.children[a].visit_count / sum_visits if a in root.children else 0
        for a in action_space
    ])
    self.root_values.append(root.value())
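
# Hedged micro-example of the normalization performed above, with toy visit
# counts (not from any real run): three visited children out of a four-action
# space yield the stored policy target below.
counts = {0: 6, 1: 3, 2: 1}  # visit_count per child action index
total = sum(counts.values())
policy_target = [counts.get(a, 0) / total for a in range(4)]
print(policy_target)  # [0.6, 0.3, 0.1, 0.0]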
def move(self, move):
    if move in self.root.children:
        self.root = self.root.children[move]
        self.root.parent = None
    else:
        # new_state = copy.deepcopy(self.root.state)
        # new_state.place_chess(move[0], move[1])
        self.root = Node(parent=None, prior_prob=1.0)
def test_backpropagate(backpropagate):
    """Return True iff the backpropagate method is implemented correctly."""
    init_state = GomokuState(use_default_heuristics=True, reward_player=0)
    init_node = Node(init_state)

    # assemble all of the moves
    black_move_0 = init_node.add_child(GomokuAction(0, (4, 4)))
    white_move_1 = black_move_0.add_child(GomokuAction(1, (5, 4)))
    black_move_2 = white_move_1.add_child(GomokuAction(1, (6, 4)))
    black_move_3 = white_move_1.add_child(GomokuAction(1, (5, 5)))
    black_move_4 = white_move_1.add_child(GomokuAction(1, (5, 3)))
    white_move_5 = black_move_2.add_child(GomokuAction(1, (6, 5)))
    white_move_6 = black_move_2.add_child(GomokuAction(1, (7, 4)))

    # assign values to the "terminal" moves and back-propagate
    backpropagate(black_move_2, 5)
    backpropagate(black_move_3, -1)
    backpropagate(black_move_4, 3)
    backpropagate(white_move_5, -4)
    backpropagate(white_move_6, 3)

    # check each node's num_samples and tot_reward
    assert_equal(black_move_0.num_samples, 5, "wrong number of samples!")
    assert_equal(white_move_1.num_samples, 5, "wrong number of samples!")
    assert_equal(black_move_2.num_samples, 3, "wrong number of samples!")
    assert_equal(black_move_3.num_samples, 1, "wrong number of samples!")
    assert_equal(black_move_4.num_samples, 1, "wrong number of samples!")
    assert_equal(white_move_5.num_samples, 1, "wrong number of samples!")
    assert_equal(white_move_6.num_samples, 1, "wrong number of samples!")
    assert_equal(black_move_0.tot_reward, 6, "wrong total reward!")
    assert_equal(white_move_1.tot_reward, 6, "wrong total reward!")
    assert_equal(black_move_2.tot_reward, 4, "wrong total reward!")
    assert_equal(black_move_3.tot_reward, -1, "wrong total reward!")
    assert_equal(black_move_4.tot_reward, 3, "wrong total reward!")
    assert_equal(white_move_5.tot_reward, -4, "wrong total reward!")
    assert_equal(white_move_6.tot_reward, 3, "wrong total reward!")
    return True
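
# A reference implementation of `backpropagate` consistent with the
# num_samples / tot_reward bookkeeping the test above checks (assuming Node
# exposes mutable num_samples, tot_reward, and parent attributes):
def backpropagate(node, reward):
    """Propagate a rollout reward from `node` up through its ancestors."""
    while node is not None:
        node.num_samples += 1
        node.tot_reward += reward
        node = node.parent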
def get_next_move(self, game):
    if len(game.history) < 2:
        self.current_node = Node(game.state, game.player_turn, parent=None)
    else:
        opponent_move = game.history[-1]
        self.current_node = self.current_node.edges[opponent_move]
    leaf = select_leaf(self.current_node, game)
    rollout(leaf, game)
    action = select_action(self.current_node, training=False)
    self.current_node = self.current_node.edges[action]
    return action
def findBestMove(self):
    # Returns the best move using Monte Carlo Tree Search
    o = Node(self.board)
    b1 = self.board.board
    # best-move parameters
    bestMove = MCTS(self.maxMinutes, o, self.factor)
    b = copy.deepcopy(bestMove.state)
    b2 = b.board
    col = FindColumn(b1, b2)
    print("MonteCarloColumn: " + str(col))
    print(b2)
    # SetMoveM(col)
    return col
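
# Hedged sketch of the `FindColumn` helper used above: it is assumed to diff
# the board before and after MCTS picked its move and return the column where
# a new piece appeared (the 2-D list board layout is an assumption).
def FindColumn(before, after):
    for row_b, row_a in zip(before, after):
        for col, (cell_b, cell_a) in enumerate(zip(row_b, row_a)):
            if cell_b != cell_a:
                return col
    return None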
def get_next_move(self, game):
    if len(game.history) < 2:
        self.current_node = Node(game.state, player_id=game.player_turn, parent=None)
    else:
        opponent_move = game.history[-1]
        self.current_node = self.current_node.edges[opponent_move]
    move = mcts_search(self.current_node, self.net, game, self.n_simulations,
                       self.C_puct, self.dirichlet_alpha, self.training)
    self.current_node = self.current_node.edges[move]
    return move
def run_sim():
    state = np.zeros((6, 7))
    root = Node(None, state)
    n = NetworkMock()
    mcts_agent = MCTS(ConnectXRules, n)
    mcts_agent.get_best_move(root, 1, 0)
    for action in root.actions:
        print(action.visit_count, end=' ')
    print()
    print(mcts_agent.winning_moves)
def plan(target_mol):
    """Generate a synthesis plan for a target molecule (in SMILES form).
    If a path is found, returns a list of (action, state) tuples.
    If a path is not found, returns None."""
    root = Node(state={target_mol})
    path = mcts(root, expansion, rollout, iterations=2000, max_depth=200)
    if path is None:
        print('No synthesis path found. Try increasing `iterations` or `max_depth`.')
    else:
        print('Path found:')
        path = [(n.action, n.state) for n in path[1:]]
    return path
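
# Hedged sketch of the generic `mcts(root, expansion, rollout, ...)` driver
# that `plan` calls; all internals here are assumptions, not the original
# code. It presumes nodes carry children / parent / is_terminal attributes
# plus visits / value counters, and returns the most-visited root-to-terminal
# path, or None if that line of play never reaches a terminal state.
import math
import random


def mcts(root, expansion, rollout, iterations=2000, max_depth=200, c=1.4):
    def select(node):
        while node.children:
            node = max(
                node.children,
                key=lambda n: n.value / (n.visits or 1)
                + c * math.sqrt(math.log(node.visits + 1) / (n.visits or 1)))
        return node

    def backpropagate(node, reward):
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent

    for _ in range(iterations):
        leaf = select(root)
        if not leaf.is_terminal:
            leaf.children = expansion(leaf)
            if leaf.children:
                leaf = random.choice(leaf.children)
        backpropagate(leaf, rollout(leaf, max_depth=max_depth))

    # Follow the most-visited line of play; succeed only if it ends terminal.
    path, node = [root], root
    while node.children:
        node = max(node.children, key=lambda n: n.visits)
        path.append(node)
    return path if path[-1].is_terminal else None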
def setUp(self):
    self.root = Node(n=5)
    self.root.parent = self.root

    # First level of test tree
    self.best_child = Node(parent=self.root, v=10, n=2)
    children = [self.best_child,
                Node(parent=self.root, v=0, n=1),
                Node(parent=self.root, v=-10, n=1)]
    self.root.children = children

    # Second level of test tree
    self.best_grandchild = Node(parent=self.best_child, v=10, n=1)
    best_child_children = [self.best_grandchild,
                           Node(parent=self.best_child, v=5, n=1)]
    self.best_child.children = best_child_children

    # Third (leaf) level of test tree
    self.test_leaf = Node(parent=self.best_grandchild, v=1)
    self.best_grandchild.children = [self.test_leaf, Node(parent=self.best_grandchild)]
def expansion(node):
    """Try expanding each molecule in the current state to possible reactants"""
    # Assume each mol is a SMILES string
    mols = node.state

    # Convert mols to format for prediction
    # If the mol is in the starting set, ignore
    mols = [mol for mol in mols if mol not in starting_mols]
    fprs = policies.fingerprint_mols(mols)

    # Predict applicable rules
    preds = sess.run(expansion_net.pred_op, feed_dict={
        expansion_net.keep_prob: 1.,
        expansion_net.X: fprs,
        expansion_net.k: 5
    })

    # Generate children for reactants
    children = []
    for mol, rule_idxs in zip(mols, preds):
        # State for children will not include this mol
        new_state = node.state - {mol}
        mol = Chem.MolFromSmiles(mol)
        for idx in rule_idxs:
            # Extract actual rule
            rule = expansion_rules[idx]

            # TODO filter_net should check if the reaction will work?
            # should do as a batch

            # Apply rule
            reactants = transform(mol, rule)
            if not reactants:
                continue

            state = new_state | set(reactants)
            terminal = all(mol in starting_mols for mol in state)
            child = Node(state=state, is_terminal=terminal,
                         parent=node, action=rule)
            children.append(child)
    return children
def rollout(node, max_depth=200):
    cur = node
    for _ in range(max_depth):
        if cur.is_terminal:
            break

        # Select a random mol (that's not a starting mol)
        mols = [mol for mol in cur.state if mol not in starting_mols]
        mol = random.choice(mols)

        fprs = policies.fingerprint_mols([mol])

        # Predict applicable rules
        preds = sess.run(rollout_net.pred_op, feed_dict={
            rollout_net.keep_prob: 1.,
            rollout_net.X: fprs,
            rollout_net.k: 1
        })
        rule = rollout_rules[preds[0][0]]
        reactants = transform(Chem.MolFromSmiles(mol), rule)

        # State for children will not include this mol
        state = (cur.state | set(reactants)) - {mol}
        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur, action=rule)
    else:
        # Max depth exceeded
        print('Rollout reached max depth')

    # Partial reward if some starting molecules are found;
    # equals 1. if a full solution is found
    reward = sum(1 for mol in cur.state if mol in starting_mols) / len(cur.state)

    # Reward of -1 if no starting molecules are found
    if reward == 0:
        return -1.
    return reward
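
# Worked example of the partial reward above (toy numbers): a final state of
# four molecules of which two are purchasable starting material scores
# 2 / 4 = 0.5; if none are purchasable, the rollout returns -1.0 instead.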
def run_trials():
    metrics_mcts = []
    for i in range(trials):
        env.reset()
        m = Metric('step', 'score')
        root = Node(0, 10)
        mcts = Mcts(run_action, root)
        done = False
        while not done:
            done = mcts.step()
        for j, r in enumerate(root.results):
            m.add_record(j, r)
        metrics_mcts.append(m)
    print('Score by MCTS:', sum(root.results))
def _step_callback(self):
    '''Callback function for the button that steps through the game world'''
    if not self.game.terminal:
        root = Node(0)
        current_observation = self.game.make_image(-1, self.network.device)
        expand_node(root, self.game.to_play(), self.game.legal_actions(),
                    self.network.initial_inference(current_observation))
        add_exploration_noise(self.config, root)

        # We then run a Monte Carlo Tree Search using only action sequences
        # and the model learned by the network.
        run_mcts(self.config, root, self.game.action_history(), self.network)
        # action = select_action(self.config, len(self.game.history), root)
        action = select_action(self.config, 9, root)
        self.game.apply(action)
        self.game.store_search_statistics(root)
        self.draw_area.draw()
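
# Hedged sketch of `expand_node` in the style of the published MuZero
# pseudocode this snippet appears to follow: child priors come from a softmax
# over the policy logits restricted to legal actions. Node(prior) and the
# network-output fields are assumptions from that pseudocode, not this repo.
import math


def expand_node(node, to_play, actions, network_output):
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)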
def get_move(self, game):
    s_init = game.to_string_representation()
    root = Node(None, None, self.player, s_init)
    self.mct.root = root
    self.mct.root_state = deepcopy(game)
    start_time = time.time()
    while time.time() - start_time < self.timeout:
        leaf_node, leaf_state, path, actions = self.mct.select()
        turn = leaf_state.player
        outcome = self.mct.rollout(leaf_state, path)
        # print(f'{outcome} {outcome == self.mct.player}')
        self.mct.backprop(leaf_node, turn, outcome, path, actions)
    dist = self.mct.get_action_distribution()
    return game.LEGAL_MOVES[dist.index(max(dist))]
def expansion(node):
    """Try expanding each molecule in the current state to possible reactants"""
    # Assume each mol is a SMILES string
    mols = node.state

    # Convert mols to format for prediction
    mol_docs = []
    for mol in mols:
        # If the mol is in the starting set, ignore
        if mol in starting_mols:
            continue
        # Preprocess for model
        doc = to_doc(mol)
        mol_docs.append((mol, doc))

    # Predict reactants
    mols_ordered, docs = zip(*mol_docs)
    preds = model.sess.run(model.pred_op, feed_dict={
        model.keep_prob: 1.,
        model.X: pad_arrays(docs),
        model.max_decode_iter: 500,
        # model.beam_width: 10
    })

    # Generate children for reactants
    children = []
    for mol, seqs in zip(mols_ordered, preds):
        # State for children will not include this mol
        new_state = mols - {mol}
        for s in seqs:
            reactants, reagents = process_seq(s)
            # TODO should we discard reagents? or store them on edges?
            state = new_state | set(reactants)
            terminal = all(mol in starting_mols for mol in state)
            child = Node(state=state, is_terminal=terminal, parent=node)
            children.append(child)
    return children
class TestMCTS(unittest.TestCase):

    def setUp(self) -> None:
        self.root = Node()

    def test_is_leaf(self) -> None:
        self.assertTrue(self.root.is_leaf())

    def test_add_children(self) -> None:
        children = [Estate(), Duchy(), Province()]
        self.root.add_unique_children(children)
        self.assertEqual(self.root.children[0].parent, self.root)
        self.assertEqual(self.root.children[1].parent, self.root)
        self.assertEqual(self.root.children[2].parent, self.root)
        self.root.add_unique_children(children)
        self.assertEqual(len(self.root.children), len(children))

    def test_get_child(self) -> None:
        children = [Estate(), Duchy(), Province()]
        self.root.add_unique_children(children)
        self.assertIsNotNone(self.root.get_child_node(Estate()))
        self.assertIsNone(self.root.get_child_node(Colony()))
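
# Hedged sketch of a Node that would satisfy the tests above. Deduplication
# in add_unique_children is keyed on the child card's type, which is an
# assumption, since the real equality rule for Estate/Duchy/etc. isn't shown.
class Node:
    def __init__(self, parent=None, card=None):
        self.parent = parent
        self.card = card
        self.children = []

    def is_leaf(self):
        return not self.children

    def add_unique_children(self, cards):
        for card in cards:
            if self.get_child_node(card) is None:
                self.children.append(Node(parent=self, card=card))

    def get_child_node(self, card):
        for child in self.children:
            if type(child.card) is type(card):
                return child
        return None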
def rollout(node, max_depth=200):
    cur = node
    for _ in range(max_depth):
        if cur.is_terminal:
            break

        # Select a random mol (that's not a starting mol)
        mols = [mol for mol in cur.state if mol not in starting_mols]
        mol = random.choice(mols)
        print('INPUT:', mol)

        # Preprocess for model
        doc = to_doc(mol)
        preds = model.sess.run(model.pred_op, feed_dict={
            model.keep_prob: 1.,
            model.X: [doc],
            model.max_decode_iter: 500,
            # model.beam_width: 1
        })
        seq = preds[0][0]
        reactants, reagents = process_seq(seq)
        print('OUTPUT:', set(reactants))

        # TODO ignore reagents or what?
        # State for children will not include this mol
        state = (cur.state | set(reactants)) - {mol}
        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur)
    else:
        # Max depth exceeded
        print('Rollout reached max depth')
        return 0.

    # TODO look up rewards from paper
    return 1.
def play_game(config, network, train):
    """
    Each game is produced by starting at the initial board position, then
    repeatedly executing a Monte Carlo Tree Search to generate moves until
    the end of the game is reached.
    """
    game = config.new_game()
    game_history = GameHistory()
    observation = game.reset()
    game_history.apply(0, observation, 0)

    while not game.terminal() and len(game_history.action_history) < config.max_moves:
        # At the root of the search tree we use the representation function
        # to obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = torch.tensor(observation).float().unsqueeze(0)
        expand_node(config, root, game.to_play(), game.legal_actions(),
                    network.initial_inference(current_observation))
        if train:
            add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences
        # and the model learned by the networks.
        run_mcts(config, root, game, network)
        action = select_action(config, len(game_history.action_history), root, train)
        observation, reward = game.step(action)
        game_history.store_search_statistics(root, config.action_space)
        game_history.apply(action, observation, reward)
    game.close()
    return game_history
def play_game(config: MuZeroConfig, network: Network) -> Game:
    game = Game.from_config(config)
    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function
        # to obtain a hidden state given the current observation.
        root = Node(0)
        last_observation = game.make_image(-1)
        root.expand(game.to_play(), game.legal_actions(),
                    network.initial_inference(last_observation).numpy())
        root.add_exploration_noise(config)
        # logging.debug('Running MCTS on step {}.'.format(len(game.history)))

        # We then run a Monte Carlo Tree Search using only action sequences
        # and the model learned by the network.
        run_mcts(config, root, game.action_history(), network)
        action = root.select_action(config, len(game.history), network)
        game.apply(action)
        game.store_search_statistics(root)

    logging.info('Finished episode at step {} | cumulative reward: {}'
                 .format(len(game.obs_history), sum(game.rewards)))
    return game
def ui_main(action_file=None, ai_first=False):
    board = Board()
    action = None
    ai_status = GameStatus.RedMoving if ai_first else GameStatus.BlackMoving
    node = Node(board)
    # main loop
    while not board.won:
        print_board(board)
        if board.status == ai_status:
            time_start = time.time()
            count = 0
            # run Monte Carlo simulations within the time limit
            while (time.time() - time_start) < 30:
                count += 1
                node.search()
            logging.debug("total count %d", count)
            action = node.find_best_child().action
            write_action(node, action, action_file)
        else:
            try:
                action = read_action(action_file)
            except (EOFError, KeyboardInterrupt):
                # end of input
                return
            except Exception:
                action = None
            if action is None:
                print('invalid command')
                continue
        node = node.apply_action(action)
        board = node.status
    print_board(board)
    print('game over', end='')
    if board.status == GameStatus.RedWon:
        print(', red won', end='')
    elif board.status == GameStatus.BlackWon:
        print(', black won', end='')
    print()
mc_game = HexGame(SIZE, player)
mc = MCTS(mc_game, MC_EXPLORATION_CONSTANT, a_net=ANET, epsilon=EPSILON)
for i in tqdm(range(EPISODES + 1)):
    # No action needed to reach initial state
    action = None
    state = mc_game.get_simple_state()
    # Init Monte Carlo root
    root = Node(state, player, None, action, mc_game.get_reversed_binary())
    while not actual_game.is_terminal_state():
        if i in DISPLAY_INDICES:
            visualizer.draw(actual_game.get_state(), DISPLAY_DELAY)
        # Find the best move using MCTS
        new_root, prev_root_children = mc.tree_search(root, MC_NUMBER_SEARCH_GAMES)
        # Distribution of visit counts along all arcs emanating from root
        D = [child.visits / root.visits for child in prev_root_children]
        # Add case to RBUF
        # State for children will not include this mol
        state = (cur.state | set(reactants)) - {mol}
        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur)
    else:
        # Max depth exceeded
        print('Rollout reached max depth')
        return 0.

    # TODO look up rewards from paper
    return 1.


# target_mol = '[H][C@@]12OC3=C(O)C=CC4=C3[C@@]11CCN(C)[C@]([H])(C4)[C@]1([H])C=C[C@@H]2O'
target_mol = 'CC(=O)NC1=CC=C(O)C=C1'
root = Node(state={target_mol})
path = mcts(root, expansion, rollout, iterations=2000, max_depth=200)
if path is None:
    print('No synthesis path found. Try increasing `iterations` or `max_depth`.')
else:
    print('Path found:')
    print(path)

import ipdb; ipdb.set_trace()