Exemple #1
0
    def search(self, n_mcts, Env, mcts_env, H=30):
        """Run ``n_mcts`` search iterations starting at the root node.

        A ``ThompsonSamplingState`` root is created when the tree is empty;
        otherwise the existing root is re-used after detaching it from its
        parent action.

        Args:
            n_mcts: Number of search iterations to run.
            Env: Environment positioned at the root. It is deep-copied before
                every iteration unless it is an Atari game.
            mcts_env: Scratch environment. For Atari games it is restored
                from a snapshot of ``Env`` before every iteration; otherwise
                the passed object is replaced by a fresh copy of ``Env``.
            H: Forwarded unchanged to ``search_iteration`` (presumably the
                search horizon; confirm against that method).

        Raises:
            ValueError: If the root state is terminal.
        """
        if self.root is not None:
            # Keep the existing tree; the root has no incoming action.
            self.root.parent_action = None
        else:
            # Empty tree: build a fresh root node.
            self.root = ThompsonSamplingState(self.root_index,
                                              r=0.0,
                                              terminal=False,
                                              parent_action=None,
                                              na=self.na,
                                              model=self)
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        atari = is_atari_game(Env)
        # Atari environments are snapshotted once and restored per iteration.
        snapshot = copy_atari_state(Env) if atari else None

        for _ in range(n_mcts):
            start_node = self.root  # every iteration starts at the root
            if atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                mcts_env = copy.deepcopy(Env)

            mcts_env.seed()
            self.search_iteration(mcts_env, start_node, 0, H)
    def search(self, n_mcts, c, env, mcts_env):
        """Run ``n_mcts`` select/expand/back-up traces from the root.

        Args:
            n_mcts: Number of traces to run.
            c: Constant passed to ``state.select`` as ``c``.
            env: Environment positioned at the root. It is deep-copied before
                every trace unless it is an Atari game.
            mcts_env: Scratch environment. For Atari games it is restored
                from a snapshot of ``env`` before every trace; otherwise the
                passed object is replaced by a fresh copy of ``env``.

        Raises:
            ValueError: If the root state is terminal.
        """
        if self.root is not None:
            # Keep the existing tree; the root has no incoming action.
            self.root.parent_action = None
        else:
            # Empty tree: build a fresh root node.
            self.root = State(
                self.root_index,
                r=0.0,
                terminal=False,
                parent_action=None,
                na=self.na,
                bootstrap_last_state_value=self.bootstrap_last_state_value,
                model=self.model)
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        atari = is_atari_game(env)
        # Atari environments are snapshotted once and restored per trace.
        snapshot = copy_atari_state(env) if atari else None

        for _ in range(n_mcts):
            node = self.root  # every trace starts at the root
            if atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                mcts_env = copy.deepcopy(env)

            # Walk down the tree, stepping the scratch environment alongside.
            while not node.terminal:
                chosen = node.select(c=c)
                obs, reward, done, _ = mcts_env.step(chosen.index)
                if not hasattr(chosen, 'child_state'):
                    # Action without a child yet: expand and stop descending.
                    node = chosen.add_child_state(obs, reward, done,
                                                  self.model)
                    break
                node = chosen.child_state

            # Back-up: propagate the discounted return up to the root.
            ret = node.V
            while node.parent_action is not None:
                ret = node.r + self.gamma * ret
                edge = node.parent_action
                edge.update(ret)
                node = edge.parent_state
                node.update()
Exemple #3
0
    def search(self, n_mcts, c, Env, mcts_env, budget, max_depth=200):
        """Run search traces from the root until ``budget`` is exhausted.

        Args:
            n_mcts: Not used by this method.
            c: Base constant for ``state.select``. When
                ``self.depth_based_bias`` is set it is scaled by
                ``gamma ** depth / (1 - gamma)``.
            Env: Environment positioned at the root. It is deep-copied before
                every trace unless it is an Atari game.
            mcts_env: Scratch environment. It is handed to a newly created
                root, and for Atari games it is restored from a snapshot of
                ``Env`` before every trace.
            budget: Remaining budget. It drops by one each time an existing
                terminal child is reached, and is otherwise updated by
                ``add_child_state``.
            max_depth: ``max_depth`` minus the current depth is forwarded to
                ``add_child_state`` on expansion.

        Raises:
            ValueError: If the root state is terminal.
        """
        if self.root is not None:
            # Keep the existing tree; the root has no incoming action.
            self.root.parent_action = None
        else:
            # Empty tree: build a fresh root node.
            self.root = State(self.root_index,
                              r=0.0,
                              terminal=False,
                              parent_action=None,
                              na=self.na,
                              env=mcts_env,
                              budget=budget)
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        atari = is_atari_game(Env)
        # Atari environments are snapshotted once and restored per trace.
        snapshot = copy_atari_state(Env) if atari else None

        while budget > 0:
            node = self.root  # every trace starts at the root
            if atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                mcts_env = copy.deepcopy(Env)

            depth = 0
            while not node.terminal:
                if self.depth_based_bias:
                    bias = c * self.gamma ** depth / (1 - self.gamma)
                else:
                    bias = c
                chosen = node.select(c=bias)
                depth += 1
                obs, reward, done, _ = mcts_env.step(chosen.index)
                if not hasattr(chosen, 'child_state'):
                    # Action without a child yet: expand and stop descending.
                    node, budget = chosen.add_child_state(
                        obs, reward, done, budget,
                        env=mcts_env, max_depth=max_depth - depth)
                    break
                node = chosen.child_state
                if node.terminal:
                    # Reaching a known terminal child still costs one unit.
                    budget -= 1

            # Back-up: propagate the return from the leaf to the root.
            ret = node.V
            node.update()
            while node.parent_action is not None:
                # A terminal node contributes only its own reward.
                ret = node.r if node.terminal else node.r + self.gamma * ret
                edge = node.parent_action
                edge.update(ret)
                node = edge.parent_state
                node.update()
    def search(self, n_mcts, c, Env, mcts_env, budget, max_depth=200):
        """Run budgeted search traces from the root of a stochastic tree.

        At each visited action the environment is stepped only while
        ``ceil(beta * n ** alpha)`` is at least the action's current number
        of children; otherwise one of the stored children is re-sampled and
        the scratch environment is moved to that child's signature.

        Args:
            n_mcts: Not used by this method.
            c: Base constant for ``state.select``. When
                ``self.depth_based_bias`` is set it is scaled by
                ``gamma ** depth / (1 - gamma)``.
            Env: Environment positioned at the root. It is deep-copied before
                every trace unless it is an Atari game.
            mcts_env: Scratch environment. For non-Atari games the passed
                object is replaced by a copy of ``Env``; for Atari games it
                is restored from a snapshot of ``Env`` before every trace.
            budget: Remaining budget. It drops by one for every environment
                step and for every re-sampled terminal child, and is also
                updated by ``add_child_state``.
            max_depth: ``max_depth`` minus the current depth is forwarded to
                ``add_child_state`` on expansion.

        Raises:
            ValueError: If the root state is terminal.
            SystemExit: If the scratch environment's signature does not
                match the signature of ``Env``.
        """
        # Local import: only the abort path below needs it. ``sys.exit`` is
        # used instead of the ``exit`` helper injected by the ``site`` module,
        # which is not guaranteed to exist (e.g. ``python -S``).
        import sys

        is_atari = is_atari_game(Env)
        if is_atari:
            snapshot = copy_atari_state(Env)  # for Atari: snapshot the root at the beginning
        else:
            mcts_env = copy.deepcopy(Env)  # copy original Env to rollout from

        # Check that the environment has been copied correctly: both
        # signatures must have the same keys and element-wise equal values.
        copy_sig = mcts_env.get_signature()
        env_sig = Env.get_signature()
        if (copy_sig.keys() != env_sig.keys()
                or not all(np.array_equal(copy_sig[key], env_sig[key]) for key in copy_sig)):
            print("Something wrong while copying the environment")
            print(copy_sig.keys(), env_sig.keys())
            sys.exit()

        if self.root is None:
            # initialize new root
            self.root = StochasticState(self.root_index, r=0.0, terminal=False, parent_action=None, na=self.na,
                                        signature=Env.get_signature(), env=mcts_env, budget=budget)
        else:
            self.root.parent_action = None  # continue from current root
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        while budget > 0:
            state = self.root  # reset to root for new trace
            if not is_atari:
                mcts_env = copy.deepcopy(Env)  # copy original Env to rollout from
            else:
                restore_atari_state(mcts_env, snapshot)
            mcts_env.seed()
            st = 0  # depth of the current trace
            while not state.terminal:
                bias = c * self.gamma ** st / (1 - self.gamma) if self.depth_based_bias else c
                action = state.select(c=bias)
                st += 1
                # Cap on the number of children this action may hold at its
                # current visit count.
                k = np.ceil(self.beta * action.n ** self.alpha)
                if k >= action.n_children:
                    # Below the cap: sample a transition from the environment.
                    s1, r, t, _ = mcts_env.step(action.index)
                    budget -= 1
                    # Look the successor up once (-1 means "not stored yet").
                    child_ind = action.get_state_ind(s1)
                    if child_ind != -1:
                        state = action.child_states[child_ind]  # select
                        state.r = r
                    else:
                        state, budget = action.add_child_state(s1, r, t, mcts_env.get_signature(), budget, env=mcts_env,
                                                               max_depth=max_depth - st)  # expand
                        break
                else:
                    # Cap reached: re-use a stored child and move the scratch
                    # environment to that child's saved signature.
                    state = action.sample_state()
                    mcts_env.set_signature(state.signature)
                    if state.terminal:
                        budget -= 1

            # Back-up: propagate the return from the leaf to the root.
            R = state.V
            state.update()
            while state.parent_action is not None:  # loop back-up until root is reached
                if not state.terminal:
                    R = state.r + self.gamma * R
                else:
                    # A terminal node contributes only its own reward.
                    R = state.r
                action = state.parent_action
                action.update(R)
                state = action.parent_state
                state.update()
    def search(self, n_mcts, c, Env, mcts_env, max_depth=200):
        """Run ``n_mcts`` search traces from the root of a stochastic tree.

        At each visited action the environment is stepped only while
        ``ceil(c * n ** alpha)`` is at least the action's current number of
        children; otherwise one of the stored children is re-sampled and the
        scratch environment is moved to that child's signature.

        Args:
            n_mcts: Number of traces to run.
            c: Constant passed to ``state.select``; also used as the
                coefficient of the child-count cap.
            Env: Environment positioned at the root. It is deep-copied before
                every trace unless it is an Atari game.
            mcts_env: Scratch environment. For non-Atari games the passed
                object is replaced by a copy of ``Env``; for Atari games it
                is restored from a snapshot of ``Env`` before every trace.
            max_depth: Forwarded to the root ``StochasticState`` when a new
                root is created.

        Raises:
            ValueError: If the root state is terminal.
        """
        is_atari = is_atari_game(Env)
        if is_atari:
            snapshot = copy_atari_state(
                Env)  # for Atari: snapshot the root at the beginning
        else:
            mcts_env = copy.deepcopy(Env)  # copy original Env to rollout from

        # Sanity check on the copy; a mismatch is only reported, not fatal.
        if mcts_env._state != Env._state:
            print("Copying went wrong")
        if self.root is None:
            # initialize new root
            self.root = StochasticState(self.root_index,
                                        r=0.0,
                                        terminal=False,
                                        parent_action=None,
                                        na=self.na,
                                        model=self.model,
                                        signature=Env.get_signature(),
                                        max_depth=max_depth)
        else:
            self.root.parent_action = None  # continue from current root
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        for _ in range(n_mcts):
            state = self.root  # reset to root for new trace
            if not is_atari:
                mcts_env = copy.deepcopy(
                    Env)  # copy original Env to rollout from
            else:
                restore_atari_state(mcts_env, snapshot)
            mcts_env.seed()
            while not state.terminal:
                action = state.select(c=c)
                # Cap on the number of children this action may hold at its
                # current visit count.
                k = np.ceil(c * action.n**self.alpha)
                if k >= action.n_children:
                    # Below the cap: sample a transition from the environment.
                    s1, r, t, _ = mcts_env.step(action.index)
                    # Look the successor up once (-1 means "not stored yet").
                    child_ind = action.get_state_ind(s1)
                    if child_ind != -1:
                        state = action.child_states[child_ind]  # select
                        state.r = r
                    else:
                        state = action.add_child_state(
                            s1, r, t, self.model,
                            mcts_env.get_signature())  # expand
                        break
                else:
                    # Cap reached: re-use a stored child and move the scratch
                    # environment to that child's saved signature.
                    state = action.sample_state()
                    mcts_env.set_signature(state.signature)

            # Back-up: propagate the return from the leaf to the root.
            R = state.V
            state.update()
            while state.parent_action is not None:  # loop back-up until root is reached
                if not state.terminal:
                    R = state.r + self.gamma * R
                else:
                    # A terminal node contributes only its own reward.
                    R = state.r
                action = state.parent_action
                action.update(R)
                state = action.parent_state
                state.update()
    def search(self,
               n_mcts,
               c,
               Env,
               mcts_env,
               budget,
               max_depth=200,
               fixed_depth=True):
        """Run budgeted search traces from the root of a particle-based tree.

        Args:
            n_mcts: Not used by this method.
            c: Constant passed to ``state.select`` as ``c``.
            Env: Environment positioned at the root. When ``self.sampler`` is
                falsy, ``self.n_particles`` deep copies of it are made for a
                new root and again before every trace.
            mcts_env: Not used by this method (Atari is not supported).
            budget: Remaining budget. It drops by the number of particles
                each time an existing terminal child is reached, and is
                otherwise updated by ``add_child_state``.
            max_depth: Depth forwarded to ``add_child_state`` on expansion.
            fixed_depth: If true, ``max_depth`` is forwarded as-is; otherwise
                the current trace depth is subtracted from it first.

        Raises:
            ValueError: If the root state is terminal.
            NotImplementedError: If ``Env`` is an Atari game.
        """
        if self.root is None:
            # initialize new root with many equal particles

            # Per-particle environment copies are only needed when no sampler
            # is available, and only while building a new root.
            root_envs = None
            if not self.sampler:
                root_envs = [
                    copy.deepcopy(Env) for _ in range(self.n_particles)
                ]

            signature = Env.get_signature()

            box = None
            # NOTE(review): the guard inspects ``self`` while the call goes to
            # ``Env.index_to_box``; confirm which object is meant. Behaviour
            # is kept exactly as it was.
            to_box = getattr(self, "index_to_box", None)
            if callable(to_box):
                box = Env.index_to_box(signature["state"])

            # ``random.randint`` needs integer bounds: a float bound such as
            # ``1e7`` raises TypeError on Python 3.12+.
            particles = [
                Particle(state=signature,
                         seed=random.randint(0, 10**7),
                         reward=0,
                         terminal=False,
                         info=box) for _ in range(self.n_particles)
            ]
            self.root = State(parent_action=None,
                              na=self.na,
                              envs=root_envs,
                              particles=particles,
                              sampler=self.sampler,
                              root=True,
                              budget=budget)
        else:
            self.root.parent_action = None  # continue from current root

        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        if is_atari_game(Env):
            # Snapshot/restore of Atari states is not implemented here.
            raise NotImplementedError

        while budget > 0:
            state = self.root  # reset to root for new trace
            mcts_envs = None
            if not self.sampler:
                mcts_envs = [
                    copy.deepcopy(Env) for _ in range(self.n_particles)
                ]  # copy original Env to rollout from
            st = 0  # depth of the current trace
            while not state.terminal:
                action = state.select(c=c)
                st += 1
                if hasattr(action, 'child_state'):
                    state = action.child_state  # select
                    if state.terminal:
                        # Reaching a known terminal child costs one unit per
                        # particle it holds.
                        budget -= len(state.particles)
                    continue
                else:
                    rollout_depth = max_depth if fixed_depth else max_depth - st
                    state, budget = action.add_child_state(
                        state, mcts_envs, budget, self.sampler,
                        rollout_depth)  # expand
                    break

            # Back-up: propagate the return from the leaf to the root.
            R = state.V
            state.update()
            while state.parent_action is not None:  # loop back-up until root is reached
                if not state.terminal:
                    R = state.r + self.gamma * R
                else:
                    # A terminal node contributes only its own reward.
                    R = state.r
                action = state.parent_action
                action.update(R)
                state = action.parent_state
                state.update()