Example no. 1
def ui_main(action_file=None, ai_first=False, depth=50, breadth=10):
    board = Board()
    ai, action = None, None
    ai_status = GameStatus.RedMoving if ai_first else GameStatus.BlackMoving
    # main loop
    while not board.won:
        print_board(board)
        if board.status == ai_status:
            ai = Node(board)
            action = ai.search(depth, breadth)
            write_action(ai, action, action_file)
        else:
            try:
                action = read_action(action_file)
            except (EOFError, KeyboardInterrupt):
                # end of input
                return
            except Exception:
                action = None
            if action is None:
                print('invalid command')
                continue
        board.apply_action(action)
    print_board(board)
    print('game over', end='')
    if board.status == GameStatus.RedWon:
        print(', red won', end='')
    elif board.status == GameStatus.BlackWon:
        print(', black won', end='')
    print()
Example no. 2
    def getPlayerMove(self):
        if self._board.is_game_over():
            print("Referee told me to play but the game is over!")
            return "PASS"
        start = time.perf_counter()

        if self.tree is None:
            self.tree = Node(None, None)
        while time.perf_counter() - start <= 5:
            leaf, actions = self.tree.select()
            for action in actions:
                self._board.push(action)
            if not self._board.is_game_over():
                leaf.expand(self._board.legal_moves())
                value = int(self.rollout())
            else:
                value = int(self._board.final_go_score()[0].lower() ==
                            Goban.Board.player_name(self._mycolor)[0])
            leaf.update(value)
            for action in actions:
                self._board.pop()

        node, move, value, incertitude = self.tree.select_move(
            self._board.legal_moves())
        # New here: allows considering internal representations of moves
        print("I am playing ", self._board.move_to_str(move), "with score:",
              value, "~", incertitude)
        print("My current board :")
        self._board.prettyPrint()

        self._board.push(move)
        # move is an internal representation. To communicate with the interface I need to change it to a string
        return Goban.Board.flat_to_name(move)
Example no. 3
    def test_two_node_tree(self):
        root = Node(None, None)
        child = root.add_child("move")
        self.assertTrue(root.is_root())
        self.assertFalse(root.is_leaf())
        self.assertFalse(child.is_root())
        self.assertTrue(child.is_leaf())
        self.assertIs(child.parent(), root)
        self.assertEqual(child.action(), "move")
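The test above pins down a small Node API: a root built as Node(None, None), add_child(action) returning the new child, and accessor methods is_root, is_leaf, parent and action. A minimal sketch that would satisfy these checks (an illustrative assumption, not the original class):

class Node:
    """Minimal tree node; just enough to satisfy test_two_node_tree."""

    def __init__(self, parent, action):
        self._parent = parent
        self._action = action
        self._children = []

    def add_child(self, action):
        # create a child linked back to this node and return it
        child = Node(self, action)
        self._children.append(child)
        return child

    def is_root(self):
        return self._parent is None

    def is_leaf(self):
        return not self._children

    def parent(self):
        return self._parent

    def action(self):
        return self._action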
Example no. 4
def test_expand(expand):
    """return true iff expand method is implemented correctly
    """
    # initialize a blank Gomoku board
    gomoku_init_state = GomokuState(use_default_heuristics=True,
                                    reward_player=0)
    gomoku_init_node = Node(gomoku_init_state)

    # black makes first move
    init_node = Node(gomoku_init_state)
    black_node = init_node.add_child(GomokuAction(0, (4, 4)))
    black_actions = list(black_node.unused_edges)
    num_edges = len(black_actions)
    num_samples = 500
    deviation = .20
    white_nodes = list(
        [black_node.add_child(action) for action in black_actions])

    # count the results of calling `expand` many times
    frequency_dict = {}
    for i in range(num_edges * num_samples):
        init_node = Node(gomoku_init_state)
        black_node = init_node.add_child(GomokuAction(0, (4, 4)))
        white_node = expand(black_node)
        if white_node not in white_nodes:
            print(white_node)
            raise ValueError("returned a Node that is not associated "
                             "with an untried action!")
        # count occurrences per node, creating new keys as needed
        key = str(white_node)
        frequency_dict[key] = frequency_dict.get(key, 0) + 1
    # check that expand is behaving via random selection
    for value in frequency_dict.values():
        if abs(value - num_samples) > num_samples * deviation:
            raise ValueError("possible actions are not being sampled"
                             " uniformly at random!")

    # check that exception is raised
    init_node = Node(gomoku_init_state)
    black_node = init_node.add_child(GomokuAction(0, (4, 4)))
    white_move_1 = black_node.add_child(GomokuAction(1, (3, 4)))
    white_move_2 = black_node.add_child(GomokuAction(1, (3, 5)))
    white_move_3 = black_node.add_child(GomokuAction(1, (4, 5)))
    white_move_4 = black_node.add_child(GomokuAction(1, (5, 5)))
    white_move_5 = black_node.add_child(GomokuAction(1, (5, 4)))
    white_move_6 = black_node.add_child(GomokuAction(1, (5, 3)))
    white_move_7 = black_node.add_child(GomokuAction(1, (4, 3)))
    white_move_8 = black_node.add_child(GomokuAction(1, (3, 3)))
    try:
        expand(black_node)
    except Exception:
        pass
    else:
        raise Exception("Should throw an exception for trying to expand "
                        "a node that has already been expanded")
    return True
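The test spells out the contract for expand: it must return a child built from one of the node's untried actions, sample those actions uniformly at random, and raise once no untried actions remain. A minimal sketch consistent with those checks, assuming Node exposes the unused_edges and add_child members used above:

import random

def expand(node):
    # pick one of the node's untried actions uniformly at random
    untried = list(node.unused_edges)
    if not untried:
        # nothing left to try: the node is already fully expanded
        raise ValueError("node has already been fully expanded")
    return node.add_child(random.choice(untried))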
Example no. 5
def one_run(env, n_turns, steepness, noise):

    env.max_turns = n_turns
    env.steepness = steepness
    env.noise_factor = noise

    trials = int(20 * 400 / n_turns)

    t = time.time()
    metrics_mcts_v3 = []
    for i in range(trials):
        env.reset()
        m = Metric('step', 'score')
        root = Node(0, 10)
        mcts = Mcts(root)

        done = False
        while not done:
            action = mcts.decide()
            _, r, done, _ = env.step(action)
            mcts.register(r)

        for j, r in enumerate(root.results):
            m.add_record(j, r)

        metrics_mcts_v3.append(m)

    metrics_mcts_v3 = sum(metrics_mcts_v3)
    print('Time for MCTSv3:', time.time() - t)

    t = time.time()
    import random
    metrics_rnd = []
    for i in range(trials):

        env.reset()
        m = Metric('step', 'score')
        rand_results = []
        done = False
        while not done:
            action = random.random() * 10
            _, r, done, _ = env.step(action)
            rand_results.append(r)

        for j, r in enumerate(rand_results):
            m.add_record(j, r)

        metrics_rnd.append(m)

    print('Time for RND:', time.time() - t)

    plot_group({
        'mcts_v3': metrics_mcts_v3,
        'random': sum(metrics_rnd)
    },
        'temp', name=f'{n_turns}_st{steepness}_n{noise}')
Example no. 6
def play_game(config: MuZeroConfig, network: Network) -> Game:

    game = config.new_game()

    while not game.terminal() and len(game.history) < config.max_moves:

        # create a new starting point for MCTS
        root = Node(0)
        current_observation = game.make_image(-1)

        root.expand_node(game.to_play(), game.legal_actions(),
                         network.initial_inference(current_observation))
        root.add_exploration_noise()

        # carry out the MCTS search
        run_mcts(config, root, game.action_history(), network)
        
        T = config.visit_softmax_temperature(
            num_moves=len(game.history),
            training_steps=network.training_steps())

        # first action from the MCTS with some extra exploration
        action, c1 = root.select_action_with_temperature(T, epsilon=config.epsilon)
        game.apply(action)
        game.store_search_statistics(root) 
        
        # continue using actions as predicted by MCTS
        # (minimise exploration for these)
        ct = 1
        while not game.terminal() and ct < config.prediction_steps:
            action, c1 = c1.select_action_with_temperature(1)
            game.apply(action)
            game.store_search_statistics(c1)
            ct += 1
        
    return game
Example no. 7
    def play_game(self, game):

        if self.config.fixed_temperatures is not None:
            self.temperature = self.config.visit_softmax_temperature(
                self.training_step)

        while not game.terminal:
            root = Node(0)

            current_observation = np.float32(game.get_observation(-1))
            if self.config.norm_obs:
                current_observation = (current_observation -
                                       self.obs_min) / self.obs_range
            current_observation = torch.from_numpy(current_observation).to(
                self.device)

            initial_inference = self.network.initial_inference(
                current_observation.unsqueeze(0))

            legal_actions = game.environment.legal_actions()
            root.expand(initial_inference, game.to_play, legal_actions)
            root.add_exploration_noise(self.config.root_dirichlet_alpha,
                                       self.config.root_exploration_fraction)

            self.mcts.run(root, self.network)

            error = root.value() - initial_inference.value.item()
            game.history.errors.append(error)

            action = self.config.select_action(root, self.temperature)

            game.apply(action)
            game.store_search_statistics(root)

            self.experiences_collected += 1

            if self.experiences_collected % self.config.weight_sync_frequency == 0:
                self.sync_weights()

            save_history = (
                game.history_idx -
                game.previous_collect_to) == self.config.max_history_length
            if save_history or game.done or game.terminal:
                overlap = self.config.num_unroll_steps + self.config.td_steps
                if not game.history.dones[game.previous_collect_to - 1]:
                    collect_from = max(0, game.previous_collect_to - overlap)
                else:
                    collect_from = game.previous_collect_to
                history = game.get_history_sequence(collect_from)
                ignore = overlap if not game.done else None
                self.replay_buffer.save_history.remote(history,
                                                       ignore=ignore,
                                                       terminal=game.terminal)

            if game.step >= self.config.max_steps:
                self.environment.was_real_done = True
                break

        if self.config.two_players:
            self.stats_to_log[game.info["result"]] += 1
Example no. 8
    def store_search_statistics(self, root: Node):
        children_nodes = root.children.values()
        # total playthroughs extending from the root
        sum_visits = sum(child.visit_count for child in children_nodes)
        action_space = (Action(index) for index in range(self.action_space_size))
        self.child_visits.append([
            root.children[a].visit_count / sum_visits if a in root.children else 0
            for a in action_space
        ])
        self.root_values.append(root.value())
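As a quick check of the normalization, with hypothetical visit counts over a four-action space (actions 0 and 2 visited 6 and 2 times), the appended policy vector is just the visit fractions, with zeros for unvisited actions:

counts = {0: 6, 2: 2}         # hypothetical visit counts per action index
total = sum(counts.values())  # 8 playthroughs in total
policy = [counts.get(a, 0) / total for a in range(4)]
assert policy == [0.75, 0.0, 0.25, 0.0]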
Example no. 9
    def store_search_statistics(self, root: Node):

        sum_visits = sum(child.visit_count for child in root.children.values())
        action_space = (Action(index)
                        for index in range(self.action_space_size))
        self.child_visits.append([
            root.children[a].visit_count /
            sum_visits if a in root.children else 0 for a in action_space
        ])
        self.root_values.append(root.value())
Example no. 10
    def move(self, move):
        if move in self.root.children:
            self.root = self.root.children[move]
            self.root.parent = None
        else:
            # new_state = copy.deepcopy(self.root.state)
            # new_state.place_chess(move[0], move[1])
            self.root = Node(parent=None, prior_prob=1.0)
Example no. 11
def test_backpropagate(backpropagate):
    """return true iff backpropagate method is implemented correctly
    """
    init_state = GomokuState(use_default_heuristics=True, reward_player=0)
    init_node = Node(init_state)

    # assemble all of the moves
    black_move_0 = init_node.add_child(GomokuAction(0, (4, 4)))
    white_move_1 = black_move_0.add_child(GomokuAction(1, (5, 4)))
    black_move_2 = white_move_1.add_child(GomokuAction(1, (6, 4)))
    black_move_3 = white_move_1.add_child(GomokuAction(1, (5, 5)))
    black_move_4 = white_move_1.add_child(GomokuAction(1, (5, 3)))
    white_move_5 = black_move_2.add_child(GomokuAction(1, (6, 5)))
    white_move_6 = black_move_2.add_child(GomokuAction(1, (7, 4)))

    # assign values to the "terminal" moves and back-propagate
    backpropagate(black_move_2, 5)
    backpropagate(black_move_3, -1)
    backpropagate(black_move_4, 3)
    backpropagate(white_move_5, -4)
    backpropagate(white_move_6, 3)

    # check the num_samples and tot_reward values of each node
    assert_equal(black_move_0.num_samples, 5, "wrong number of samples!")
    assert_equal(white_move_1.num_samples, 5, "wrong number of samples!")
    assert_equal(black_move_2.num_samples, 3, "wrong number of samples!")
    assert_equal(black_move_3.num_samples, 1, "wrong number of samples!")
    assert_equal(black_move_4.num_samples, 1, "wrong number of samples!")
    assert_equal(white_move_5.num_samples, 1, "wrong number of samples!")
    assert_equal(white_move_6.num_samples, 1, "wrong number of samples!")
    assert_equal(black_move_0.tot_reward, 6, "wrong total reward!")
    assert_equal(white_move_1.tot_reward, 6, "wrong total reward!")
    assert_equal(black_move_2.tot_reward, 4, "wrong total reward!")
    assert_equal(black_move_3.tot_reward, -1, "wrong total reward!")
    assert_equal(black_move_4.tot_reward, 3, "wrong total reward!")
    assert_equal(white_move_5.tot_reward, -4, "wrong total reward!")
    assert_equal(white_move_6.tot_reward, 3, "wrong total reward!")

    return True
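The assertions fix the expected semantics: each backpropagate call adds one sample and the given reward to the starting node and to every ancestor up to the root. A minimal sketch consistent with those checks, assuming each Node stores num_samples, tot_reward and a parent reference that is None at the root:

def backpropagate(node, reward):
    # walk from the given node up to the root, accumulating the reward
    while node is not None:
        node.num_samples += 1
        node.tot_reward += reward
        node = node.parent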
Example no. 12
    def get_next_move(self, game):
        if len(game.history) < 2:
            self.current_node = Node(game.state, game.player_turn, parent=None)
        else:
            opponent_move = game.history[-1]
            self.current_node = self.current_node.edges[opponent_move]

        leaf = select_leaf(self.current_node, game)
        rollout(leaf, game)
        action = select_action(self.current_node, training=False)
        self.current_node = self.current_node.edges[action]

        return action
Example no. 13
    def findBestMove(self):
        # Returns the best move (column) using Monte Carlo Tree Search
        o = Node(self.board)
        b1 = self.board.board
        # best move parameters
        bestMove = MCTS(self.maxMinutes, o, self.factor)
        b = copy.deepcopy(bestMove.state)
        b2 = b.board
        col = FindColumn(b1, b2)
        print("MonteCarloColumn: " + str(col))
        print(b2)
        # SetMoveM(col)
        return col
Example no. 14
    def get_next_move(self, game):
        if len(game.history) < 2:
            self.current_node = Node(game.state,
                                     player_id=game.player_turn,
                                     parent=None)
        else:
            opponent_move = game.history[-1]
            self.current_node = self.current_node.edges[opponent_move]
        move = mcts_search(self.current_node, self.net, game,
                           self.n_simulations, self.C_puct,
                           self.dirichlet_alpha, self.training)
        self.current_node = self.current_node.edges[move]

        return move
Example no. 15
def run_sim():
    state = np.zeros((6, 7))
    root = Node(None, state)
    n = NetworkMock()

    mcts_agent = MCTS(ConnectXRules, n)

    mcts_agent.get_best_move(root, 1, 0)

    for action in root.actions:
        print(action.visit_count, end=' ')
    print()
    print(mcts_agent.winning_moves)
Example no. 16
def plan(target_mol):
    """Generate a synthesis plan for a target molecule (in SMILES form).
    If a path is found, returns a list of (action, state) tuples.
    If a path is not found, returns None."""
    root = Node(state={target_mol})

    path = mcts(root, expansion, rollout, iterations=2000, max_depth=200)
    if path is None:
        print(
            'No synthesis path found. Try increasing `iterations` or `max_depth`.'
        )
    else:
        print('Path found:')
        path = [(n.action, n.state) for n in path[1:]]
    return path
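A minimal usage sketch for plan, assuming the module-level mcts, expansion, rollout and Node used above are in scope; the SMILES string is the acetaminophen target that appears in the driver code at the end of this collection:

path = plan('CC(=O)NC1=CC=C(O)C=C1')
if path is not None:
    for action, state in path:
        print(action, state)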
Example no. 17
    def setUp(self):
        self.root = Node(n=5)
        self.root.parent = self.root

        # First level of test tree
        self.best_child = Node(parent=self.root, v=10, n=2)
        children = [self.best_child, Node(parent=self.root, v=0, n=1), Node(parent=self.root, v=-10, n=1)]
        self.root.children = children

        # Second level of test tree
        self.best_grandchild = Node(parent=self.best_child, v=10, n=1)
        best_child_children = [self.best_grandchild, Node(parent=self.best_child, v=5, n=1)]
        self.best_child.children = best_child_children

        # Third (leaf) level of test tree
        self.test_leaf = Node(parent=self.best_grandchild, v=1)
        self.best_grandchild.children = [self.test_leaf, Node(parent=self.best_grandchild)]
Example no. 18
def expansion(node):
    """Try expanding each molecule in the current state
    to possible reactants"""

    # Assume each mol is a SMILES string
    mols = node.state

    # Convert mols to format for prediction
    # If the mol is in the starting set, ignore
    mols = [mol for mol in mols if mol not in starting_mols]
    fprs = policies.fingerprint_mols(mols)

    # Predict applicable rules
    preds = sess.run(expansion_net.pred_op,
                     feed_dict={
                         expansion_net.keep_prob: 1.,
                         expansion_net.X: fprs,
                         expansion_net.k: 5
                     })

    # Generate children for reactants
    children = []
    for mol, rule_idxs in zip(mols, preds):
        # State for children will not include this mol
        # (node.state is a set; the filtered `mols` above is a list, so subtract from node.state)
        new_state = node.state - {mol}

        mol = Chem.MolFromSmiles(mol)
        for idx in rule_idxs:
            # Extract actual rule
            rule = expansion_rules[idx]

            # TODO filter_net should check if the reaction will work?
            # should do as a batch

            # Apply rule
            reactants = transform(mol, rule)

            if not reactants: continue

            state = new_state | set(reactants)
            terminal = all(mol in starting_mols for mol in state)
            child = Node(state=state,
                         is_terminal=terminal,
                         parent=node,
                         action=rule)
            children.append(child)
    return children
Example no. 19
def rollout(node, max_depth=200):
    cur = node
    for _ in range(max_depth):
        if cur.is_terminal:
            break

        # Select a random mol (that's not a starting mol)
        mols = [mol for mol in cur.state if mol not in starting_mols]
        mol = random.choice(mols)
        fprs = policies.fingerprint_mols([mol])

        # Predict applicable rules
        preds = sess.run(rollout_net.pred_op,
                         feed_dict={
                             rollout_net.keep_prob: 1.,
                             rollout_net.X: fprs,
                             rollout_net.k: 1
                         })

        rule = rollout_rules[preds[0][0]]
        reactants = transform(Chem.MolFromSmiles(mol), rule)
        state = cur.state | set(reactants)

        # State for children will
        # not include this mol
        state = state - {mol}

        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur, action=rule)

    # Max depth exceeded
    else:
        print('Rollout reached max depth')

        # Partial reward if some starting molecules are found
        reward = sum(1 for mol in cur.state if mol in starting_mols) / len(
            cur.state)

        # Reward of -1 if no starting molecules are found
        if reward == 0:
            return -1.

        return reward

    # Reward of 1 if solution is found
    return 1.
Example no. 20
def run_trials():
    metrics_mcts = []

    for i in range(trials):
        env.reset()
        m = Metric('step', 'score')
        root = Node(0, 10)
        mcts = Mcts(run_action, root)

        done = False
        while not done:
            done = mcts.step()

        for j, r in enumerate(root.results):
            m.add_record(j, r)

        metrics_mcts.append(m)
        print('Score by MCTS:', sum(root.results))
Example no. 21
    def _step_callback(self):
        '''Callback function for button that steps through the game world'''
        if not self.game.terminal:
            root = Node(0)
            current_observation = self.game.make_image(-1, self.network.device)
            expand_node(root, self.game.to_play(), self.game.legal_actions(),
                        self.network.initial_inference(current_observation))
            add_exploration_noise(self.config, root)

            # We then run a Monte Carlo Tree Search using only action sequences and the
            # model learned by the network.
            run_mcts(self.config, root, self.game.action_history(),
                     self.network)
            #action = select_action(self.config, len(self.game.history), root)
            action = select_action(self.config, 9, root)
            self.game.apply(action)
            self.game.store_search_statistics(root)
        self.draw_area.draw()
Example no. 22
    def get_move(self, game):

        s_init = game.to_string_representation()
        root = Node(None, None, self.player, s_init)
        self.mct.root = root
        self.mct.root_state = deepcopy(game)

        start_time = time.time()

        while time.time() - start_time < self.timeout:
            leaf_node, leaf_state, path, actions = self.mct.select()
            turn = leaf_state.player
            outcome = self.mct.rollout(leaf_state, path)
            # print(f'{outcome} {outcome == self.mct.player}')
            self.mct.backprop(leaf_node, turn, outcome, path, actions)

        dist = self.mct.get_action_distribution()
        return game.LEGAL_MOVES[dist.index(max(dist))]
Example no. 23
def expansion(node):
    """Try expanding each molecule in the current state
    to possible reactants"""

    # Assume each mol is a SMILES string
    mols = node.state

    # Convert mols to format for prediction
    mol_docs = []
    for mol in mols:
        # If the mol is in the starting set, ignore
        if mol in starting_mols:
            continue

        # Preprocess for model
        doc = to_doc(mol)
        mol_docs.append((mol, doc))

    # Predict reactants
    mols_ordered, docs = zip(*mol_docs)
    preds = model.sess.run(model.pred_op, feed_dict={
        model.keep_prob: 1.,
        model.X: pad_arrays(docs),
        model.max_decode_iter: 500,
        # model.beam_width: 10
    })

    # Generate children for reactants
    children = []
    for mol, seqs in zip(mols_ordered, preds):
        # State for children will
        # not include this mol
        new_state = mols - {mol}

        for s in seqs:
            reactants, reagents = process_seq(s)
            # TODO should we discard reagents?
            # or store them on edges?

            state = new_state | set(reactants)
            terminal = all(mol in starting_mols for mol in state)
            child = Node(state=state, is_terminal=terminal, parent=node)
            children.append(child)
    return children
Example no. 24
class TestMCTS(unittest.TestCase):
    def setUp(self) -> None:
        self.root = Node()

    def test_is_leaf(self) -> None:
        self.assertTrue(self.root.is_leaf())

    def test_add_children(self) -> None:
        children = [Estate(), Duchy(), Province()]
        self.root.add_unique_children(children)

        self.assertEqual(self.root.children[0].parent, self.root)
        self.assertEqual(self.root.children[1].parent, self.root)
        self.assertEqual(self.root.children[2].parent, self.root)

        self.root.add_unique_children(children)
        self.assertEqual(len(self.root.children), len(children))

    def test_get_child(self) -> None:
        children = [Estate(), Duchy(), Province()]
        self.root.add_unique_children(children)

        self.assertIsNotNone(self.root.get_child_node(Estate()))
        self.assertIsNone(self.root.get_child_node(Colony()))
Example no. 25
def rollout(node, max_depth=200):
    cur = node
    for _ in range(max_depth):
        if cur.is_terminal:
            break

        # Select a random mol (that's not a starting mol)
        mols = [mol for mol in cur.state if mol not in starting_mols]
        mol = random.choice(mols)
        print('INPUT:', mol)

        # Preprocess for model
        doc = to_doc(mol)

        preds = model.sess.run(model.pred_op, feed_dict={
            model.keep_prob: 1.,
            model.X: [doc],
            model.max_decode_iter: 500,
            # model.beam_width: 1
        })
        seq = preds[0][0]
        reactants, reagents = process_seq(seq)
        print('OUTPUT:', set(reactants))

        # TODO ignore reagents or what?

        state = cur.state | set(reactants)

        # State for children will
        # not include this mol
        state = state - {mol}

        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur)

    # Max depth exceeded
    else:
        print('Rollout reached max depth')
        return 0.

    # TODO look up rewards from paper
    return 1.
Example no. 26
def play_game(config, network, train):
    """
    Each game is produced by starting at the initial board position, then
    repeatedly executing a Monte Carlo Tree Search to generate moves until the end
    of the game is reached.
    """
    game = config.new_game()

    game_history = GameHistory()
    observation = game.reset()
    game_history.apply(0, observation, 0)

    while not game.terminal() and len(
            game_history.action_history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game_history.make_image(-1)
        current_observation = torch.tensor(current_observation).float().unsqueeze(0)

        expand_node(config, root, game.to_play(), game.legal_actions(),
                    network.initial_inference(current_observation))
        if train:
            add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the networks.
        run_mcts(config, root, game, network)
        action = select_action(config, len(game_history.action_history), root,
                               train)

        observation, reward = game.step(action)
        game_history.store_search_statistics(root, config.action_space)
        game_history.apply(action, observation, reward)

    game.close()

    return game_history
Example no. 27
def play_game(config: MuZeroConfig, network: Network) -> Game:
    game = Game.from_config(config)

    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        last_observation = game.make_image(-1)
        root.expand(game.to_play(), game.legal_actions(),
                    network.initial_inference(last_observation).numpy())
        root.add_exploration_noise(config)

        # logging.debug('Running MCTS on step {}.'.format(len(game.history)))
        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the network.
        run_mcts(config, root, game.action_history(), network)
        action = root.select_action(config, len(game.history), network)
        game.apply(action)
        game.store_search_statistics(root)

    logging.info('Finished episode at step {} | cumulative reward: {}' \
        .format(len(game.obs_history), sum(game.rewards)))

    return game
Example no. 28
def ui_main(action_file=None, ai_first=False):
    board = Board()
    action = None
    ai_status = GameStatus.RedMoving if ai_first else GameStatus.BlackMoving
    node = Node(board)
    # main loop
    while not board.won:
        print_board(board)
        if board.status == ai_status:
            time_start = time.time()
            count = 0
            # run Monte Carlo simulations within the time limit
            while time.time() - time_start < 30:
                count += 1
                node.search()
            logging.debug("total count %d", count)
            action = node.find_best_child().action
            write_action(node, action, action_file)
        else:
            try:
                action = read_action(action_file)
            except (EOFError, KeyboardInterrupt):
                # end of input
                return
            except Exception:
                action = None
            if action is None:
                print('invalid command')
                continue
        node = node.apply_action(action)
        board = node.status
    print_board(board)
    print('game over', end='')
    if board.status == GameStatus.RedWon:
        print(', red won', end='')
    elif board.status == GameStatus.BlackWon:
        print(', black won', end='')
    print()
Example no. 29
        mc_game = HexGame(SIZE, player)
        mc = MCTS(mc_game,
                  MC_EXPLORATION_CONSTANT,
                  a_net=ANET,
                  epsilon=EPSILON)

        for i in tqdm(range(EPISODES + 1)):

            # No action needed to reach initial state
            action = None

            state = mc_game.get_simple_state()

            # Init Monte Carlo root
            root = Node(state, player, None, action,
                        mc_game.get_reversed_binary())

            while not actual_game.is_terminal_state():
                if i in DISPLAY_INDICES:
                    visualizer.draw(actual_game.get_state(), DISPLAY_DELAY)

                # Find the best move using MCTS
                new_root, prev_root_children = mc.tree_search(
                    root, MC_NUMBER_SEARCH_GAMES)

                # Distribution of visit counts along all arcs emanating from root
                D = [
                    child.visits / root.visits for child in prev_root_children
                ]

                # Add case to RBUF
Example no. 30
        state = cur.state | set(reactants)

        # State for children will
        # not include this mol
        state = state - {mol}

        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur)

    # Max depth exceeded
    else:
        print('Rollout reached max depth')
        return 0.

    # TODO look up rewards from paper
    return 1.


# target_mol = '[H][C@@]12OC3=C(O)C=CC4=C3[C@@]11CCN(C)[C@]([H])(C4)[C@]1([H])C=C[C@@H]2O'
target_mol = 'CC(=O)NC1=CC=C(O)C=C1'
root = Node(state={target_mol})

path = mcts(root, expansion, rollout, iterations=2000, max_depth=200)
if path is None:
    print('No synthesis path found. Try increasing `iterations` or `max_depth`.')
else:
    print('Path found:')
    print(path)
    import ipdb; ipdb.set_trace()