Example #1
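    # Tree-policy methods of a UCT/MCTS agent class (excerpt). They assume that
    # numpy is imported as `np`, that `copy` is imported, and that the OR_Node
    # tree-node class and the agent attributes used below (self._C, self._atari,
    # self._caching, self.stateCount, self._exp_graph, ...) are defined elsewhere.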
    def pick_action(self, env, n: OR_Node, action, history):
        """Simulate `action` at node `n`; return the successor node reached."""
        history.append(n.children[action])
        next_state, reward, terminal, _ = env.step(action)
        self.sim_calls += 1
        next_state = np.reshape(next_state,
                                [1, np.prod(env.observation_space.shape)])
        succ = OR_Node(next_state, n.d + 1, terminal)
        if n.children[action].update(reward, succ):
            if self.tabulate_state_visits:
                tupOfState = tuple(succ.state[0].tolist())
                succ.num_visits = self.stateCount.get(tupOfState, 0)
            else:
                succ.num_visits = 0  # if we get a new successor
            succ.r = reward
            self._exp_graph.update(n, action, reward, succ)
            node = succ
            node.v = 0
            node.num_rollouts = 0
        else:
            node = None
            for child in n.children[action].children:
                succ1, reward1 = child
                if reward1 == reward and succ1 == succ:
                    # wizluk.logger.debug("the state exists")
                    node = succ1
                    break
            assert node is not None  # the (successor, reward) pair must already be a child
        history.append(node)
        n.children[action].visited = True

        self.max_depth = max(self.max_depth, node.d)
        return node, history
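
    # bestChild is the UCT selection step: every action of OR node `n` is scored
    # with Q(n, a) + C * sqrt(2 * ln(N(n)) / N(n, a)) (infinite for unvisited
    # actions), the best-scoring action is simulated, and for Atari with caching
    # enabled a previously stored emulator state is restored instead of stepping
    # the environment again.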
    def bestChild(self, env, n: OR_Node, history):
        assert (len(n.children) > 0)
        L = [
            float('inf') if n.children[k].num_visits == 0 else
            n.children[k].Q + self._C *
            np.sqrt(2 * np.log(n.num_visits) / n.children[k].num_visits)
            for k in n.children.keys()
        ]
        selected = list(n.children.keys())[np.argmax(L)]

        history.append(n.children[selected])
        if (self._atari == "True" and len(n.children[selected].children) != 0
                and self._caching != "None"):
            elapsed_steps = env._elapsed_steps
            for node, reward in n.children[selected].children:
                if hasattr(node, 'restoreState') and node.restoreState is not None:
                    break
            wasRestored = False
            if hasattr(node, 'restoreState') and node.restoreState is not None:
                if node.restoreStateFrom != self.get_action_call and self._caching != "Full":
                    # State was cloned in an earlier get_action call, so with
                    # partial caching do not restore it here.
                    node.restoreState = None
                else:
                    env.unwrapped.restore_full_state(node.restoreState)
                    env._elapsed_steps = elapsed_steps + 1
                    succ = node
                    wasRestored = True

            assert wasRestored  # a cached emulator state must be restorable here
        else:
            next_state, reward, terminal, _ = env.step(selected)
            self.sim_calls += 1
            if self._atari != "True":
                next_state = np.reshape(
                    next_state, [1, np.prod(env.observation_space.shape)])
            elif self._caching != "None":
                # Atari with caching enabled: the selected arc must already have
                # cached children, so this branch should be unreachable.
                assert False
            succ = OR_Node(next_state, n.d + 1, terminal)
            if n.children[selected].update(reward, succ):
                if self.tabulate_state_visits:
                    tupOfState = tuple(succ.state[0].tolist())
                    succ.num_visits = self.stateCount.get(tupOfState, 0)
                else:
                    succ.num_visits = 0  # if we get a new successor
                succ.r = reward
            else:
                foundChild = False
                for child in n.children[selected].children:
                    succ1, reward1 = child
                    if reward1 == reward and succ1 == succ:
                        succ = succ1
                        foundChild = True
                        break
                assert foundChild

        history.append(succ)
        n.children[selected].visited = True

        self.max_depth = max(self.max_depth, succ.d)
        return succ, history
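
    # pick_random_unvisited_child is the expansion step: it samples uniformly
    # among the not-yet-visited actions of `n`, simulates the chosen action (or
    # restores a cached Atari emulator state when one is available), and stores
    # env.unwrapped.clone_full_state() on the successor so that later visits can
    # restore it cheaply.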
    def pick_random_unvisited_child(self, env, n: OR_Node, history):
        candidates = [
            k for k in n.children.keys() if not n.children[k].visited
        ]
        selected = np.random.choice(candidates)
        history.append(n.children[selected])
        if (self._atari == "True" and len(n.children[selected].children) != 0
                and self._caching != "None"):
            elapsed_steps = env._elapsed_steps
            wasRestored = False
            for node, reward in n.children[selected].children:
                if hasattr(node, 'restoreState') and node.restoreState is not None:
                    break
            if hasattr(node, 'restoreState') and node.restoreState is not None:
                if node.restoreStateFrom != self.get_action_call and self._caching != "Full":
                    # State was cloned in an earlier get_action call, so with
                    # partial caching do not restore it here.
                    node.restoreState = None
                else:
                    env.unwrapped.restore_full_state(node.restoreState)
                    env._elapsed_steps = elapsed_steps + 1
                    succ = node
                    wasRestored = True

            if not wasRestored:
                next_state, reward, terminal, _ = env.step(selected)
                self.sim_calls += 1
                if (np.array_equal(next_state, node._state)
                        and reward == node.r and terminal == node.terminal):
                    succ = node
                else:
                    succ = copy.deepcopy(node)
                    succ.r = reward
                    succ._state = copy.deepcopy(next_state)
                    succ.terminal = terminal
                    n.children[selected].children.add((succ, reward))
                succ.restoreState = env.unwrapped.clone_full_state()
                succ.restoreStateFrom = self.get_action_call
        else:
            next_state, reward, terminal, _ = env.step(selected)
            self.sim_calls += 1
            if self._atari != "True":
                next_state = np.reshape(
                    next_state, [1, np.prod(env.observation_space.shape)])
            succ = OR_Node(next_state, n.d + 1, terminal)
            if n.children[selected].update(reward, succ):
                if self.tabulate_state_visits:
                    tupOfState = tuple(succ.state[0].tolist())
                    succ.num_visits = self.stateCount.get(tupOfState, 0)
                else:
                    succ.num_visits = 0  # if we get a new successor
                succ.r = reward
            else:
                foundChild = False
                for child in n.children[selected].children:
                    succ1, reward1 = child
                    if reward1 == reward and succ1 == succ:
                        succ = succ1
                        foundChild = True
                        break
                assert foundChild
            if self._atari == "True" and self._caching != "None":
                succ.restoreState = env.unwrapped.clone_full_state()
                succ.restoreStateFrom = self.get_action_call
        history.append(succ)

        assert self.isInChildrenOnce(n.children[selected], succ)
        n.children[selected].visited = True

        self.max_depth = max(self.max_depth, succ.d)
        return succ, history
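

# For reference, a minimal self-contained sketch of the UCB1 score that
# bestChild above computes per action. The function name `ucb1_scores` and its
# (Q, visit-count) input layout are hypothetical illustration only, not part of
# the OR_Node API.
import numpy as np


def ucb1_scores(parent_visits, children, C=1.0):
    """Return one UCB1 score per (Q, n) pair: Q + C * sqrt(2 * ln(N) / n)."""
    scores = []
    for q, n_visits in children:
        if n_visits == 0:
            scores.append(float('inf'))  # untried actions are always preferred
        else:
            scores.append(q + C * np.sqrt(2 * np.log(parent_visits) / n_visits))
    return scores


# With N = 10 parent visits, the rarely tried child gets the larger exploration
# bonus even though its mean value Q is lower, so it would be selected next.
print(ucb1_scores(10, [(1.0, 8), (0.5, 2)]))  # approx [1.76, 2.02]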