示例#1
0
    def test_empty_input(self):
        """
        Verify that all-zero network inputs never yield non-finite outputs.

        Tests the following scenarios:
         - An observation tensor containing only zeros is encoded to finite values (zeros are allowed).
         - The encoding of that tensor is transitioned to finite values for every action (zeros are allowed).
        """
        # A real observation is built first, only to obtain a correctly shaped network input.
        initial_state = self.g.getInitialState()
        observation = self.g.buildObservation(
            initial_state, player=1, form=self.g.Representation.HEURISTIC)

        history = GameHistory()
        history.capture(observation, -1, 1, np.array([]), 0, 0)
        template = history.stackObservations(
            self.net.net_args.observation_length, observation)

        # Replace the stacked observation by zeros and encode it.
        empty_input = np.zeros_like(template)
        encoded, _, _ = self.net.initial_inference(empty_input)
        self.assertTrue(np.isfinite(encoded).all())

        # Exhaustively step the recurrent model with every action and collect its second output.
        transitioned = []
        for action in range(self.g.getActionSize()):
            outputs = self.net.recurrent_inference(encoded, action)
            transitioned.append(outputs[1])

        self.assertTrue(np.isfinite(np.array(transitioned)).all())
示例#2
0
class Player(ABC):
    def __init__(self,
                 game,
                 arg_file: typing.Optional[str] = None,
                 name: str = "",
                 parametric: bool = False) -> None:
        """
        Store the player's configuration and create its (empty) memory.

        :param game: The game/ environment object this player is attached to.
        :param arg_file: Optional string kept as ``player_args`` (presumably a path to a configuration file).
        :param name: Label for this player.
        :param parametric: Flag stored unchanged on the instance.
        """
        # Identity and configuration.
        self.name = name
        self.game = game
        self.player_args = arg_file
        self.parametric = parametric

        # Memory: a list of stored histories plus the history that is currently active.
        self.histories = []
        self.history = GameHistory()

    def bind_history(self, history: GameHistory) -> None:
        """
        Make an externally provided history the player's active history.

        :param history: GameHistory The object that replaces ``self.history``.
        """
        self.history = history

    def refresh(self, hard_reset: bool = False) -> None:
        """
        Reset the player's memory.

        :param hard_reset: bool If True, discard all stored histories and call ``refresh`` on the active history.
                                Otherwise, store the active history and continue with a new GameHistory.
        """
        if not hard_reset:
            # Soft reset: keep the current history and start recording into a fresh one.
            self.histories.append(self.history)
            self.history = GameHistory()
            return

        # Hard reset: forget the stored histories and reset the active one.
        self.histories = []
        self.history.refresh()

    def observe(self, state: GameState) -> None:
        """
        Record a state in the active history.

        Besides the state, an empty array and two zeros are passed as the remaining ``capture`` arguments.

        :param state: GameState The state to record.
        """
        empty_array = np.array([])
        self.history.capture(state, empty_array, 0, 0)

    def clone(self):
        """
        Construct a new player of the same concrete class.

        Only ``self.game`` and ``self.player_args`` are forwarded to the constructor.

        :return: A new instance of this player's class.
        """
        player_cls = self.__class__
        return player_cls(self.game, self.player_args)

    @abstractmethod
    def act(self, state: GameState) -> int:
        """
示例#3
0
 def refresh(self, hard_reset: bool = False) -> None:
     """
     Reset the player's memory.

     :param hard_reset: bool If True, discard all stored histories and call ``refresh`` on the active history.
                             Otherwise, store the active history and continue with a new GameHistory.
     """
     if not hard_reset:
         # Soft reset: keep the current history and start recording into a fresh one.
         self.histories.append(self.history)
         self.history = GameHistory()
         return

     # Hard reset: forget the stored histories and reset the active one.
     self.histories = []
     self.history.refresh()
示例#4
0
    def learn(self) -> None:
        """
        Run the outer training loop: gather self-play data, fit the network, and optionally gate the new weights.

        For each of 'num_selfplay_iterations' iterations:
         1. Play 'num_episodes' episodes of self-play and append them to the data history. This step is skipped
            in the first iteration when 'update_on_checkpoint' is set, so that training goes directly to
            backpropagation.
         2. Print statistics about the data history and store it (can be slow).
         3. Save the current weights as 'temp.pth.tar' and perform 'num_gradient_steps' updates on sampled batches.
         4. If pitting is enabled, load the saved weights into the opponent network and let the arena decide
            whether to accept the newly fitted weights. Accepted weights are checkpointed; rejected weights are
            rolled back to 'temp.pth.tar'.
        """
        for i in range(1, self.args.num_selfplay_iterations + 1):
            print(f'------ITER {i}------')

            # Only the very first iteration may skip self-play.
            skip_selfplay = self.update_on_checkpoint and i <= 1
            if not skip_selfplay:
                episodes = []
                for _ in trange(self.args.num_episodes, desc="Self Play", file=sys.stdout):
                    self.mcts.clear_tree()
                    episodes.append(self.executeEpisode())

                    # Drop the oldest episode once the summed episode lengths exceed the buffer size.
                    gathered = sum(len(episode) for episode in episodes)
                    if gathered > self.args.max_buffer_size:
                        del episodes[0]

                # Keep this iteration's episodes alongside those of earlier iterations.
                self.trainExamplesHistory.append(episodes)

            # Report on, and back up, everything gathered so far.
            GameHistory.print_statistics(self.trainExamplesHistory)
            self.saveTrainExamples(i - 1)

            # Merge the per-iteration data into one structure to sample training batches from.
            complete_history = GameHistory.flatten(self.trainExamplesHistory)

            # Snapshot the weights before training so they can be compared against or restored.
            self.neural_net.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')

            for _ in trange(self.args.num_gradient_steps, desc="Backpropagation", file=sys.stdout):
                batch = self.sampleBatch(complete_history)

                self.neural_net.train(batch)
                self.neural_net.monitor.log_batch(batch)

            if self.args.pitting:
                # The opponent plays with the weights from before this iteration's training.
                self.opponent_net.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')

                arena = Arena(self.game, self.arena_player, self.arena_opponent, self.args.max_trial_moves)
                accept = arena.pitting(self.args, self.neural_net.monitor)
            else:
                # Without pitting the new weights are always accepted.
                accept = True

            if not accept:
                print('REJECTING NEW MODEL')
                self.neural_net.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
                continue

            print('ACCEPTING NEW MODEL')
            self.neural_net.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(i))
            self.neural_net.save_checkpoint(folder=self.args.checkpoint, filename=self.args.load_folder_file[-1])
示例#5
0
文件: Player.py 项目: windpipe/muzero
 def refresh(self, hard_reset: bool = False) -> None:
     """
     Refresh or reinitialize the memory/ observation trajectory of the agent.

     :param hard_reset: bool If True, discard all stored histories and call ``refresh`` on the active history.
                             Otherwise, store the active history and continue with a new GameHistory.
     """
     if not hard_reset:
         # Soft reset: keep the current history and start recording into a fresh one.
         self.histories.append(self.history)
         self.history = GameHistory()
         return

     # Hard reset: forget the stored histories and reset the active one.
     self.histories = []
     self.history.refresh()
示例#6
0
    def test_search_recursion_error(self):
        """
        Guard against infinite recursion inside the MuZero tree search.

        The main phenomenon this test attempts to find is:
        Let s be the current latent state, s = [0, 0, 0], along with action a = 1.
        If fetching the next latent state with (s, a) yields s' == s = [0, 0, 0], then s' is a new state that is
        nevertheless already present in the transition table because it is identical to s. If UCB picks a = 1
        again, the search could recurse forever.

        Tests the following scenarios:
         - MuMCTS does not raise a recursion error when called with the same input multiple times without
           clearing the tree.
         - MuMCTS does not raise a recursion error when inputs are either all zeros or random.
         - MuMCTS does not raise a recursion error when only one root action is legal.
        """
        rep = 30  # Repetition factor --> should be high.

        # A real observation is built first, only to obtain a correctly shaped network input.
        initial_state = self.g.getInitialState()
        observation = self.g.buildObservation(
            initial_state, player=1, form=self.g.Representation.HEURISTIC)

        history = GameHistory()
        history.capture(observation, -1, 1, np.array([]), 0, 0)
        template = history.stackObservations(
            self.net.net_args.observation_length, observation)

        # Network inputs: one with only zeros and one with uniform random values.
        zero_input = np.zeros_like(template)
        random_input = np.random.rand(*zero_input.shape)

        # Root legal-action masks: every action legal, or only the first action legal.
        all_legal = np.ones(self.g.getActionSize())
        single_legal = np.zeros_like(all_legal)
        single_legal[0] = 1

        scenarios = (
            (zero_input, all_legal),       # Zero observations, ALL moves at the root
            (zero_input, single_legal),    # Zero observations, ONE move at the root
            (random_input, all_legal),     # Random observations, ALL moves at the root
            (random_input, single_legal),  # Random observations, ONE move at the root
        )

        # Repeat each search without clearing the tree so that tree paths recur; clear only between scenarios.
        for network_input, legal_mask in scenarios:
            for _ in range(rep):
                self.mcts.runMCTS(network_input, legal_mask)
            self.mcts.clear_tree()
示例#7
0
 def __init__(self,
              game,
              arg_file: typing.Optional[str] = None,
              name: str = "",
              parametric: bool = False) -> None:
     """
     Store the player's configuration and create its (empty) memory.

     :param game: The game/ environment object this player is attached to.
     :param arg_file: Optional string kept as ``player_args`` (presumably a path to a configuration file).
     :param name: Label for this player.
     :param parametric: Flag stored unchanged on the instance.
     """
     # Identity and configuration.
     self.name = name
     self.game = game
     self.player_args = arg_file
     self.parametric = parametric

     # Memory: a list of stored histories plus the history that is currently active.
     self.histories = []
     self.history = GameHistory()
示例#8
0
    def initialize_root(self, state: GameState,
                        trajectory: GameHistory) -> typing.Tuple[bytes, float]:
        """
        Prepare the root node of the search tree for the given environment state.

        The network is queried for a prior and a value. The prior is mixed with Dirichlet noise, masked with the
        legal moves of the state, and re-normalized. The root's visit counter is set to zero.

        :param state: GameState Data structure containing the current state of the environment.
        :param trajectory: GameHistory Data structure containing the entire episode trajectory of the agent(s).
        :return: tuple (hash, root_value) The hash of the environment state and inferred root-value.
        """
        # Network inference on the stacked observations ending in the current observation.
        stack_length = self.neural_net.net_args.observation_length
        root_input = trajectory.stackObservations(stack_length, state.observation)
        prior, root_value = self.neural_net.predict(root_input)

        root_key = self.game.getHash(state)

        # Blend the network prior with Dirichlet noise for exploration at the root.
        dirichlet = np.random.dirichlet([self.args.dirichlet_alpha] * len(prior))
        epsilon = self.args.exploration_fraction
        self.Ps[root_key] = dirichlet * epsilon + (1 - epsilon) * prior

        # Remember which moves are legal, zero out the others, and re-normalize the prior.
        legal_mask = self.game.getLegalMoves(state)
        self.Vs[root_key] = legal_mask

        masked_prior = self.Ps[root_key] * legal_mask
        self.Ps[root_key] = masked_prior / np.sum(masked_prior)

        # Visit count of the root starts at zero.
        self.Ns[root_key] = 0

        return root_key, root_value
示例#9
0
    def initialize_root(self, state: GameState, trajectory: GameHistory) -> typing.Tuple[typing.Tuple[bytes, tuple],
                                                                                         np.ndarray, float]:
        """
        Embed the root state with the MuZero network and prepare the root node of the search tree.

        The stacked observations are passed through the network's initial inference, which yields a latent state,
        a prior and a value. The prior is mixed with Dirichlet noise and, because the real environment is
        available at the root, masked with the legal moves of the state and re-normalized.

        :param state: GameState Data structure containing the current state of the environment.
        :param trajectory: GameHistory Data structure containing the entire episode trajectory of the agent(s).
        :return: tuple (hash, latent_state, root_value) The hash/ data of the latent state and inferred root-value.
        """
        # Initial inference on the stacked observations ending in the current observation.
        observation = self.game.buildObservation(state)
        stack_length = self.neural_net.net_args.observation_length
        root_input = trajectory.stackObservations(stack_length, observation)
        latent_state, prior, root_value = self.neural_net.initial_inference(root_input)

        # Hashable key of the root: the raw latent bytes paired with an empty tuple.
        root_key = (latent_state.tobytes(), ())

        # Blend the network prior with Dirichlet noise for exploration at the root.
        dirichlet = np.random.dirichlet([self.args.dirichlet_alpha] * len(prior))
        epsilon = self.args.exploration_fraction
        self.Ps[root_key] = dirichlet * epsilon + (1 - epsilon) * prior

        # Remember which moves are legal, zero out the others, and re-normalize the prior.
        legal_mask = self.game.getLegalMoves(state)
        self.Vs[root_key] = legal_mask

        masked_prior = self.Ps[root_key] * legal_mask
        self.Ps[root_key] = masked_prior / np.sum(masked_prior)

        # Visit count of the root starts at zero.
        self.Ns[root_key] = 0

        return root_key, latent_state, root_value
示例#10
0
    def test_search_border_cases_latent_state(self):
        """
        Verify that non-finite inputs are not propagated through the network.

        Tests the following scenarios:
        - Observation tensors with only infinities or nans result in finite tensors after initial inference.
        - Latent tensors whose zero entries are replaced by infinities or nans result in finite tensors after
          recurrent inference, for every action.
          Testing this ensures that bad input is not propagated for more than one step.
          Note that one forward step using bad inputs can already lead to a recursion error in MuMCTS,
          see test_search_recursion_error.
        """
        # A real observation is built first, only to obtain a correctly shaped network input.
        initial_state = self.g.getInitialState()
        observation = self.g.buildObservation(
            initial_state, player=1, form=self.g.Representation.HEURISTIC)

        history = GameHistory()
        history.capture(observation, -1, 1, np.array([]), 0, 0)
        template = history.stackObservations(
            self.net.net_args.observation_length, observation)

        # Network inputs consisting solely of nans/ infinities.
        nan_input = np.zeros_like(template)
        nan_input[nan_input == 0] = np.nan
        inf_input = np.zeros_like(template)
        inf_input[inf_input == 0] = np.inf

        # Encode both inputs; the encodings must be finite.
        nan_encoded, _, _ = self.net.initial_inference(nan_input)
        inf_encoded, _, _ = self.net.initial_inference(inf_input)

        for encoded in (nan_encoded, inf_encoded):
            self.assertTrue(np.isfinite(encoded).all())

        # Corrupt every zero entry of the encodings before feeding them to the recurrent model.
        nan_encoded[nan_encoded == 0] = np.nan
        inf_encoded[inf_encoded == 0] = np.inf

        # Exhaustively step the recurrent model with every action, first for nans and then for infinities.
        forwards = []
        for corrupted in (nan_encoded, inf_encoded):
            per_action = []
            for action in range(self.g.getActionSize()):
                per_action.append(self.net.recurrent_inference(corrupted, action)[1])
            forwards.append(per_action)

        for per_action in forwards:
            self.assertTrue(np.isfinite(np.array(per_action)).all())
示例#11
0
    def buildHypotheticalSteps(self, history: GameHistory, t: int, k: int) -> \
            typing.Tuple[np.ndarray, typing.Tuple[np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
        """
        Build the action sequence and training targets for unrolling the model k steps from time t.

        Actions are taken from t until t + k - 1 and one-hot encoded. When fewer than k actions are available
        (the episode ended), the sequence is padded with actions drawn uniformly at random over the action space.

        Move-probabilities, values, and rewards are taken from t until t + k. When fewer than k + 1 entries are
        available, the move-probabilities are padded with zero vectors and the values and rewards with zeros.
        A zero vector is an improper probability distribution, which allows a loss function to recognize the
        padded steps.

        If 'return_forward_observations' is set, the stacked observations at t + 1 until t + k are returned as
        well; otherwise an empty list is returned in their place.

        NOTE(review): padding the move-probabilities reads the last available entry, so t is assumed to index an
        existing step of the history - confirm against the caller.

        :param history: GameHistory Sampled data structure containing all statistics/ observations of a finished game.
        :param t: int The sampled index to generate the targets at.
        :param k: int The number of unrolling steps to perform/ length of the dynamics model target sequence.
        :return: Tuple of (actions, targets, future_inputs) where targets is (values, rewards, move-probabilities).
        """
        n_actions = self.game.getActionSize()

        # Actions performed at t, ..., t + k - 1, padded with random actions beyond the end of the episode.
        unroll_actions = history.actions[slice(t, t + k)]
        missing_actions = k - len(unroll_actions)
        if missing_actions > 0:
            unroll_actions += np.random.randint(n_actions, size=missing_actions).tolist()

        # One-hot encoding: row i holds a one in the column of the i-th action.
        one_hot = np.zeros([k, n_actions])
        one_hot[np.arange(len(unroll_actions)), unroll_actions] = 1

        # Targets at t, ..., t + k.
        window = slice(t, t + k + 1)
        policy_targets = history.probabilities[window]
        value_targets = history.observed_returns[window]
        reward_targets = history.rewards[window]

        # Pad the targets beyond the end of the episode with zeros.
        missing_targets = (k + 1) - len(policy_targets)
        if missing_targets > 0:
            policy_targets += [np.zeros_like(policy_targets[-1])] * missing_targets
            reward_targets += [0] * missing_targets
            value_targets += [0] * missing_targets

        # Optionally gather the stacked observations at t + 1, ..., t + k.
        future_observations = []
        if self.return_forward_observations:
            for step in range(1, k + 1):
                future_observations.append(
                    history.stackObservations(self.observation_stack_length, t=t + step))

        # (Actions, Targets, Observations)
        targets = (np.asarray(value_targets), np.asarray(reward_targets), np.asarray(policy_targets))
        return one_hot, targets, future_observations
示例#12
0
文件: Player.py 项目: windpipe/muzero
class Player(ABC):
    """ Interface for players for general environment control/ game playing. """
    def __init__(self,
                 game,
                 arg_file: typing.Optional[str] = None,
                 name: str = "",
                 parametric: bool = False) -> None:
        """
        Set up the state shared by all players.

        :param game: Game Instance of Games.Game that implements environment logic.
        :param arg_file: str Path to JSON configuration file for the agent/ player.
        :param name: str Name to annotate this player with (useful during tournaments)
        :param parametric: bool Whether the agent depends on parameters or is parameter-free.
        """
        # Identity and configuration.
        self.name = name
        self.game = game
        self.player_args = arg_file
        self.parametric = parametric

        # Memory: a list of stored histories plus the history that is currently active.
        self.histories = []
        self.history = GameHistory()

    def bind_history(self, history: GameHistory) -> None:
        """
        Use an external memory object as the player's active history.

        :param history: GameHistory The object that replaces ``self.history``.
        """
        self.history = history

    def refresh(self, hard_reset: bool = False) -> None:
        """
        Refresh or reinitialize the memory/ observation trajectory of the agent.

        :param hard_reset: bool If True, discard all stored histories and call ``refresh`` on the active history.
                                Otherwise, store the active history and continue with a new GameHistory.
        """
        if not hard_reset:
            # Soft reset: keep the current history and start recording into a fresh one.
            self.histories.append(self.history)
            self.history = GameHistory()
            return

        # Hard reset: forget the stored histories and reset the active one.
        self.histories = []
        self.history.refresh()

    def observe(self, state: GameState) -> None:
        """
        Capture an environment state observation within the agent's memory.

        Besides the state, an empty array and two zeros are passed as the remaining ``capture`` arguments.

        :param state: GameState The state to record.
        """
        empty_array = np.array([])
        self.history.capture(state, empty_array, 0, 0)

    def clone(self):
        """
        Create a new instance of this player's concrete class.

        Only ``self.game`` and ``self.player_args`` are forwarded to the constructor.

        :return: A new instance of this player's class.
        """
        player_cls = self.__class__
        return player_cls(self.game, self.player_args)

    @abstractmethod
    def act(self, state: GameState) -> int:
        """
示例#13
0
文件: Player.py 项目: windpipe/muzero
 def __init__(self,
              game,
              arg_file: typing.Optional[str] = None,
              name: str = "",
              parametric: bool = False) -> None:
     """
     Set up the state shared by all players.

     :param game: Game Instance of Games.Game that implements environment logic.
     :param arg_file: str Path to JSON configuration file for the agent/ player.
     :param name: str Name to annotate this player with (useful during tournaments)
     :param parametric: bool Whether the agent depends on parameters or is parameter-free.
     """
     # Identity and configuration.
     self.name = name
     self.game = game
     self.player_args = arg_file
     self.parametric = parametric

     # Memory: a list of stored histories plus the history that is currently active.
     self.histories = []
     self.history = GameHistory()
示例#14
0
    def executeEpisode(self) -> GameHistory:
        """
        Play a single episode of self-play and return the recorded trajectory.

        At every step the search temperature is updated, MCTS produces a move-probability vector and a value
        for the current state, an action is sampled from the move-probabilities and applied to the environment,
        and the state, move-probabilities, observed reward, and search value are captured in a GameHistory.
        The episode stops when the state is done or after 'max_episode_moves' steps. Afterwards the environment
        is closed, the GameHistory is terminated, and its return targets are computed.

        :return: GameHistory Data structure containing all observed states and statistics required for network training.
        """
        trajectory = GameHistory()
        state = self.game.getInitialState()  # Always from perspective of player 1 for boardgames.
        n_moves = 0

        while not state.done and n_moves < self.args.max_episode_moves:
            # Display visualization of the environment if specified.
            if debugging.RENDER:
                self.game.render(state)

            # The temperature schedule is indexed by the number of weight updates or by the episode step.
            if self.temp_schedule.args.by_weight_update:
                schedule_position = self.neural_net.steps
            else:
                schedule_position = n_moves
            temp = self.update_temperature(schedule_position)

            # Search from the current state for move-probabilities and a state value.
            pi, v = self.mcts.runMCTS(state, trajectory, temp=temp)

            # Sample an action, step the environment, and store the transition statistics.
            state.action = np.random.choice(len(pi), p=pi)
            next_state, r = self.game.getNextState(state, state.action)
            trajectory.capture(state, pi, r, v)

            # Hand over control to the successor state.
            state = next_state
            n_moves += 1

        # Cleanup environment and GameHistory
        self.game.close(state)
        trajectory.terminate()

        # An n-step horizon is only passed on for single-player games; otherwise None is passed.
        lookahead = self.args.n_steps if self.game.n_players == 1 else None
        trajectory.compute_returns(gamma=self.args.gamma, n=lookahead)

        return trajectory
示例#15
0
    def test_n_step_return_estimation_MDP(self):
        """
        Check the n-step return targets of GameHistory against manually calculated values.

        The expected values follow z_t = sum_{i < n} gamma^i * u_{t+i+1} + gamma^n * v_{t+n}, where rewards and
        values beyond the end of the episode count as zero. The last entry of the computed returns is excluded
        from the comparison.
        """
        horizon = 3  # n-step lookahead for computing z_t
        gamma = 1 / 2  # discount factor for future rewards and bootstrap

        # Experiment settings
        search_values = [5, 5, 5, 5, 5]  # MCTS v_t index +k
        rewards = [0, 1, 2, 3, 4]  # u_t+1 index +k
        final_return = 0  # Final return provided by the env.

        # Desired output: Correct z_t index +k (calculated manually)
        expected_returns = [1 + 5 / 8, 3 + 3 / 8, 4 + 1 / 2, 5.0, 4.0, 0]

        # Fill the GameHistory with one entry per (reward, value) pair and terminate the episode.
        history = GameHistory()
        for step in range(min(len(rewards), len(search_values))):
            history.capture(np.array([0]), -1, 1, np.array([0]), rewards[step], search_values[step])
        history.terminate(np.asarray([]), 1, final_return)

        # Check if algorithm computes z_t's correctly
        history.compute_returns(gamma, horizon)
        np.testing.assert_array_almost_equal(history.observed_returns[:-1], expected_returns[:-1])
示例#16
0
    def test_observation_stacking(self):
        """
        Check GameHistory.stackObservations for shape, content, concatenation, time indexing, and clearing.

        Tests the following scenarios:
         - Stacking 0 or 1 observations returns the single captured observation unchanged.
         - Stacking more observations than available pads the leading channels with zeros.
         - A provided current observation is concatenated behind the captured ones along the last axis.
         - The time index t selects the window of observations ending at (and including) t.
         - Refreshing the history empties it.
        """
        # random normal variables in the form (x, y, c)
        shape = (3, 3, 8)
        dummy_observations = [np.random.randn(np.prod(shape)).reshape(shape) for _ in range(10)]

        h = GameHistory()
        h.capture(dummy_observations[0], -1, 1, np.array([]), 0, 0)

        # Ensure correct shapes and content
        stacked_0 = h.stackObservations(0)
        stacked_1 = h.stackObservations(1)
        stacked_5 = h.stackObservations(5)

        np.testing.assert_array_equal(stacked_0.shape, shape)  # Shape didn't change
        np.testing.assert_array_equal(stacked_1.shape, shape)  # Shape didn't change
        np.testing.assert_array_equal(stacked_5.shape, np.array(shape) * np.array([1, 1, 5]))  # Channels * 5

        np.testing.assert_array_almost_equal(stacked_0, dummy_observations[0])  # Should be the same
        np.testing.assert_array_almost_equal(stacked_1, dummy_observations[0])  # Should be the same

        np.testing.assert_array_almost_equal(stacked_5[..., :-8], 0)  # Should be only 0s
        np.testing.assert_array_almost_equal(stacked_5[..., -8:], dummy_observations[0])  # Should be the first o_t

        # Check whether observation concatenation works correctly
        stacked = h.stackObservations(2, dummy_observations[1])
        expected = np.concatenate(dummy_observations[:2], axis=-1)
        np.testing.assert_array_almost_equal(stacked, expected)

        # Fill the buffer. A plain loop is used because capture is called only for its side effect.
        for observation in dummy_observations[1:]:
            h.capture(observation, -1, 1, np.array([]), 0, 0)

        # Check whether time indexing works correctly
        stacked_1to4 = h.stackObservations(4, t=4)  # 1-4 --> t is inclusive
        stacked_last4 = h.stackObservations(4, t=9)  # 6-9
        expected_1to4 = np.concatenate(dummy_observations[1:5], axis=-1)  # t in {1, 2, 3, 4}
        expected_last4 = np.concatenate(dummy_observations[-4:], axis=-1)  # t in {6, 7, 8, 9}

        np.testing.assert_array_almost_equal(stacked_1to4, expected_1to4)
        np.testing.assert_array_almost_equal(stacked_last4, expected_last4)

        # Check if clearing works correctly
        h.refresh()
        self.assertEqual(len(h), 0)
示例#17
0
    def _search(
        self,
        state: GameState,
        trajectory: GameHistory,
        path: typing.Tuple[int, ...] = tuple()) -> float:
        """
        Recursively perform MCTS search inside the actual environments with search-paths guided by the PUCT formula.

        Selection chooses an action for expanding/ traversing the edge (s, a) within the tree search.
        The exploration_factor for the PUCT formula is computed within this function for efficiency:

            exploration_factor = c1 + log(visits_s + c2 + 1) - log(c2)

        NOTE(review): earlier documentation stated that setting AlphaMCTS.CANONICAL to true reduces the
        exploration_factor to c1. No such switch is consulted in this method; presumably it is handled inside
        compute_ucb - confirm.

        If an edge (s, a) has not been seen before, we perform a step within the environment (with action a),
        store the observed transition and reward, and - for non-terminal successors - infer move probabilities and
        a state value with the neural network. If the edge is already known and its successor is not terminal, we
        recurse into the stored successor state.

        During backup we update the current value estimate of the edge Q(s, a) using a running average, and we
        additionally update the MinMax statistics with the new estimate. Backed-up values are multiplied by gamma.
        Unless 'single_player' is set, the backed-up value G_k is negated at each backup (adversarial games).

        The actual search-path 'path' is kept as a debugging-variable, it currently has no practical use. This method
        may raise a recursion error if the environment creates cycles, this should be highly improbable for most
        environments. If this does occur, the environment can be altered to terminate after n visits to some cycle.

        :param state: GameState The environment state at the current node of the search.
        :param trajectory: GameHistory Data structure containing all observations until the current search-depth.
        :param path: tuple of integers representing the tree search-path of the current function call.
        :return: float The backed-up discounted/ Monte-Carlo returns (dependent on gamma) of the tree search.
        :raises RecursionError: When cycles occur within the search path, the search can get stuck *ad infinitum*.
        """
        s = self.game.getHash(state)

        ### SELECTION
        # pick the action with the highest upper confidence bound
        exploration_factor = self.args.c1 + np.log(self.Ns[s] + self.args.c2 +
                                                   1) - np.log(self.args.c2)
        confidence_bounds = np.asarray([
            self.compute_ucb(s, a, exploration_factor)
            for a in range(self.action_size)
        ])
        # The argmax is taken over the legal actions only (non-zero entries of Vs[s]) and is mapped back to an
        # index into the full action space through flatnonzero.
        a = np.flatnonzero(self.Vs[s])[np.argmax(
            confidence_bounds[self.Vs[s].astype(bool)])]  # Get masked argmax.

        # Default leaf node value. Future possible future reward is 0. Variable is overwritten if edge is non-terminal.
        value = 0
        if (s, a) not in self.Ssa:  ### ROLLOUT for valid moves
            # First visit of this edge: step a cloned environment state and register the transition.
            next_state, reward = self.game.getNextState(state, a, clone=True)
            s_next = self.game.getHash(next_state)

            # Transition statistics.
            self.Rsa[(s, a)], self.Ssa[(
                s, a)], self.Ns[s_next] = reward, next_state, 0

            # Inference for non-terminal nodes.
            if not next_state.done:
                # Build network input for inference.
                # NOTE(review): the input is built from `state.observation` (the parent node), while the resulting
                # prior and value are stored for/ attributed to `s_next`. Confirm whether `next_state.observation`
                # (with the parent's observation appended to the trajectory) was intended.
                network_input = trajectory.stackObservations(
                    self.neural_net.net_args.observation_length,
                    state.observation)
                prior, value = self.neural_net.predict(network_input)

                # Inference statistics. Alternate value perspective due to adversary (model predicts for next player).
                self.Ps[s_next], self.Vs[
                    s_next] = prior, self.game.getLegalMoves(next_state)
                value = value if self.single_player else -value

        elif not self.Ssa[(s, a)].done:  ### EXPANSION
            # Known, non-terminal edge: descend into the stored successor state. The current observation is
            # pushed before the recursive call and popped afterwards, so the trajectory is left unchanged.
            trajectory.observations.append(
                state.observation
            )  # Build up an observation trajectory inside the tree
            value = self._search(self.Ssa[(s, a)], trajectory, path + (a, ))
            trajectory.observations.pop(
            )  # Clear tree observation trajectory when backing up

        ### BACKUP
        gk = self.Rsa[(
            s, a
        )] + self.args.gamma * value  # (Discounted) Value of the current node

        # Running average of the backed-up values over all visits of the edge.
        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] +
                                gk) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = gk
            self.Nsa[(s, a)] = 1

        self.minmax.update(self.Qsa[(s, a)])
        self.Ns[s] += 1

        # The caller is the other player in adversarial games, hence the sign flip.
        return gk if self.single_player else -gk