Exemplo n.º 1
0
def play_game(config: MuZeroConfig,
              network: AbstractNetwork,
              train: bool = True) -> AbstractGame:
    """
    Play one complete episode, picking every move with Monte Carlo Tree Search.

    A new game is obtained from ``config.new_game()`` and advanced one action
    at a time until ``game.terminal()`` is true or ``game.history`` holds
    ``config.max_moves`` entries.

    When ``train`` is truthy, exploration noise is added to each search root
    and actions are selected in ``'softmax'`` mode; otherwise no noise is
    added and ``'max'`` mode is used.

    Returns the game object after all actions have been applied and the
    per-move search statistics stored on it.
    """
    game = config.new_game()
    if train:
        selection_mode = 'softmax'
    else:
        selection_mode = 'max'

    while True:
        # Stop as soon as the game is over or the move budget is used up.
        if game.terminal() or len(game.history) >= config.max_moves:
            break

        # Build the search root from the network's initial inference on the
        # most recent observation (index -1).
        search_root = Node(0)
        latest_image = game.make_image(-1)
        player = game.to_play()
        legal = game.legal_actions()
        initial_output = network.initial_inference(latest_image)
        expand_node(search_root, player, legal, initial_output)
        if train:
            add_exploration_noise(config, search_root)

        # Search from the root given the actions played so far, then pick
        # the move using the configured selection mode.
        past_actions = game.action_history()
        run_mcts(config, search_root, past_actions, network)
        moves_played = len(game.history)
        chosen_action = select_action(config,
                                      moves_played,
                                      search_root,
                                      network,
                                      mode=selection_mode)

        # Advance the game first, then record the statistics of the search
        # that produced the move.
        game.apply(chosen_action)
        game.store_search_statistics(search_root)

    return game
Exemplo n.º 2
0
def play_game(config: MuZeroConfig,
              storage: SharedStorage,
              train: bool = True,
              visual: bool = False,
              queue: Queue = None) -> AbstractGame:
    """
    Play one complete episode with MCTS-selected moves and report the result.

    The network is taken from ``storage.latest_network_for_process()`` when a
    ``queue`` is supplied, and from ``storage.current_network`` otherwise.
    A new game from ``config.new_game()`` is then advanced until
    ``game.terminal()`` is true or ``game.history`` holds ``config.max_moves``
    entries.

    Args:
        config: Provides ``new_game()`` and ``max_moves``, and is forwarded to
            the search helpers.
        storage: Source of the network used for inference.
        train: When truthy, exploration noise is added to each search root and
            actions are selected in ``'softmax'`` mode; otherwise ``'max'``.
        visual: When truthy, ``game.env.render()`` is called after every move,
            the reason the episode ended is printed, and ``game.env.close()``
            is called afterwards (also if an exception interrupts the episode).
        queue: Optional queue; when given (i.e. not ``None``) the finished
            game is also ``put`` on it.

    Returns:
        The finished game object.
    """
    # Test for presence, not truthiness: a queue-like object that evaluates
    # as falsy must still be treated as "provided".
    has_queue = queue is not None
    if has_queue:
        network = storage.latest_network_for_process()
    else:
        network = storage.current_network

    start = time()
    game = config.new_game()
    mode_action_select = 'softmax' if train else 'max'
    try:
        while not game.terminal() and len(game.history) < config.max_moves:
            # At the root of the search tree we use the representation
            # function to obtain a hidden state given the current observation.
            root = Node(0)
            current_observation = game.make_image(-1)
            expand_node(root, game.to_play(), game.legal_actions(),
                        network.initial_inference(current_observation))
            if train:
                add_exploration_noise(config, root)

            # We then run a Monte Carlo Tree Search using only action
            # sequences and the model learned by the networks.
            run_mcts(config, root, game.action_history(), network)
            action = select_action(config,
                                   len(game.history),
                                   root,
                                   network,
                                   mode=mode_action_select)
            game.apply(action)
            game.store_search_statistics(root)
            if visual:
                game.env.render()

        if visual:
            if game.terminal():
                print('Model lost game')
            else:
                print('Exceeded max moves')
    finally:
        # Always release the rendering environment, even when the search or
        # rendering raised part-way through the episode.
        if visual:
            game.env.close()

    if has_queue:
        queue.put(game)
    print(f"Finished game episode after {time() - start}"
          f" seconds. Exceeded max moves? {not game.terminal()}")
    print("Score: ", sum(game.rewards))
    return game
Exemplo n.º 3
0
def play_game(config: MuZeroConfig, network: AbstractNetwork, train: bool = True) -> AbstractGame:
    """
    Generate a single game by repeatedly searching and applying moves.

    Starting from ``config.new_game()``, each iteration expands a fresh root
    node from the network's initial inference on the latest observation, runs
    MCTS from it, selects an action and applies it. The loop ends when
    ``game.terminal()`` is true or ``game.history`` reaches
    ``config.max_moves`` entries.

    ``train`` switches on exploration noise at the root and ``'softmax'``
    action selection; without it, selection uses ``'max'`` and no noise.

    Returns the played game with search statistics stored for every move.
    """
    game = config.new_game()
    action_mode = {True: 'softmax', False: 'max'}[bool(train)]

    while not (game.terminal() or len(game.history) >= config.max_moves):
        # Expand the root using the observation at index -1 (the latest one).
        tree_root = Node(0)
        observation = game.make_image(-1)
        expand_node(
            tree_root,
            game.to_play(),
            game.legal_actions(),
            network.initial_inference(observation),
        )
        if train:
            add_exploration_noise(config, tree_root)

        # Search, choose a move, play it, then record the search statistics.
        run_mcts(config, tree_root, game.action_history(), network)
        next_action = select_action(
            config,
            len(game.history),
            tree_root,
            network,
            mode=action_mode,
        )
        game.apply(next_action)
        game.store_search_statistics(tree_root)

    return game