Example #1
0
def expand_state(s, h_v, env, explicit_graph, goal, A, succs_cache=None):
    """Expand the non-goal state ``s`` and return an updated explicit graph.

    For every action ``a`` in ``A`` the successors of ``s`` are taken from
    ``succs_cache[(s, a)]`` when the cache is non-empty and holds that key,
    and otherwise from
    ``get_successor_states_check_exception(s, a, env.domain)``.
    Successors are grouped per state into adjacency entries of the form
    ``{'state': s_, 'A': {action: p}}``, where ``p`` is the value the
    successor mapping associates with ``s_`` (presumably the transition
    probability -- confirm against the successor helper).

    Args:
        s: State to expand; must already be a key of ``explicit_graph``.
        h_v: Callable giving the initial value of a newly added state.
        env: Object whose ``domain`` attribute is forwarded to the
            successor helper.
        explicit_graph: Mapping from state to node dict (keys used here:
            ``"Adj"`` and ``"expanded"``).
        goal: Goal description forwarded to ``check_goal``.
        A: Iterable of actions.
        succs_cache: Optional mapping ``(state, action) -> successors``.
            ``None`` or an empty mapping is ignored.

    Returns:
        A new graph mapping in which ``s`` is marked expanded, its ``"Adj"``
        list has the grouped successors appended, and every successor that
        is missing from, or unexpanded in, ``explicit_graph`` (other than
        ``s`` itself) has a freshly initialised node.

    Raises:
        ValueError: If ``s`` is a goal state.
    """
    if check_goal(s, goal):
        raise ValueError(
            f'State {s} can\'t be expanded because it is a goal state')

    # Group successors by state: one adjacency entry per distinct successor,
    # carrying every action that can lead to it.
    neighbour_index = {}
    neighbour_states = []
    for a in A:
        if succs_cache and (s, a) in succs_cache:
            succs = succs_cache[(s, a)]
        else:
            succs = get_successor_states_check_exception(s, a, env.domain)
        for s_, p in succs.items():
            if s_ in neighbour_index:
                neighbour_states[neighbour_index[s_]]['A'][a] = p
            else:
                neighbour_index[s_] = len(neighbour_states)
                neighbour_states.append({'state': s_, 'A': {a: p}})

    new_explicit_graph = copy(explicit_graph)

    # The copy above is shallow (assuming `copy` is `copy.copy`), so the node
    # of `s` is still shared with the caller's graph. Rebuild it instead of
    # mutating it in place, so the input graph is left untouched.
    old_node = explicit_graph[s]
    new_node = dict(old_node)
    new_node["Adj"] = list(old_node["Adj"]) + neighbour_states
    new_node["expanded"] = True
    new_explicit_graph[s] = new_node

    # Add fresh nodes for successors that are new or not yet expanded.
    for n in neighbour_states:
        n_state = n['state']
        already_expanded = (n_state in explicit_graph
                            and explicit_graph[n_state]['expanded'])
        if already_expanded or not (n_state != s):
            continue
        # Goal states start at value 0; everything else at its heuristic.
        h_v_ = 0 if check_goal(n_state, goal) else h_v(n_state)
        new_explicit_graph[n_state] = {
            "value": h_v_,
            "solved": False,
            "pi": None,
            "expanded": False,
            "Q_v": {a: h_v_
                    for a in A},
            "Adj": []
        }

    return new_explicit_graph
Example #2
0
 def visit(s, i, d, low):
     nonlocal explicit_graph, A, goal, n_updates_
     is_goal = check_goal(s, goal)
     if not is_goal and not explicit_graph[s]['expanded']:
         explicit_graph = expand_state(s,
                                       h_v,
                                       env,
                                       explicit_graph,
                                       goal,
                                       A,
                                       succs_cache=succs_cache)
     if not is_goal:
         # run bellman backup
         explicit_graph = backup_bellman(explicit_graph, A, s, goal,
                                         gamma, C)
         n_updates_ += 1
Example #3
0
def vi(S, succ_states, A, V_i, G_i, goal, env, gamma, epsilon):
    """Run value iteration with unit step cost over the states in ``S``.

    Each sweep reads from a snapshot of the previous sweep's arrays and
    updates every non-goal state with the minimum over actions of
    ``1 + sum(p * gamma * V[s'])``. The sweep loop stops once the
    infinity-norm of the change in ``V`` drops below ``epsilon``.

    Args:
        S: Iterable of states to sweep over.
        succ_states: Mapping indexed by ``(state, action)`` that yields a
            mapping from successor state to probability.
        A: Sequence of actions.
        V_i: Mapping from state to its index in the value arrays.
        G_i: Indices of goal states.
        goal: Goal description forwarded to ``check_goal``.
        env: Not used in this function.
        gamma: Discount factor applied to successor values.
        epsilon: Convergence threshold.

    Returns:
        ``(V, pi)``: the value array and an object array holding the
        minimising action per state (``None`` where never updated).

    Note:
        A second array ``P`` is maintained alongside ``V`` (set to 1 at the
        ``G_i`` indices, updated with the max over actions of
        ``sum(p * P[s'])``), but it is not returned.
    """
    n_states = len(V_i)

    V = np.zeros(n_states)
    P = np.zeros(n_states)
    pi = np.full(n_states, None)
    print(len(S), len(V_i), len(G_i), len(P))
    print(G_i)
    P[G_i] = 1

    cost = 1  # every non-goal step costs the same
    i = 0
    diff = np.inf
    while True:
        print('Iteration', i, diff)
        # Snapshots of the previous sweep; all reads below go through them.
        V_ = V.copy()
        P_ = P.copy()

        for s in S:
            if check_goal(s, goal):
                continue

            Q = np.zeros(len(A))
            Q_p = np.zeros(len(A))
            for i_a, a in enumerate(A):
                succ = succ_states[s, a]
                probs = np.fromiter(succ.values(), dtype=float)
                succ_i = [V_i[succ_s] for succ_s in succ]
                Q[i_a] = cost + np.dot(probs, gamma * V_[succ_i])
                Q_p[i_a] = np.dot(probs, P_[succ_i])

            s_idx = V_i[s]
            V[s_idx] = np.min(Q)
            P[s_idx] = np.max(Q_p)
            pi[s_idx] = A[np.argmin(Q)]

        diff = np.linalg.norm(V_ - V, np.inf)
        if diff < epsilon:
            break
        i += 1

    return V, pi
Example #4
0
 def _is_goal_reached(self, state):
     """
     Terminal-condition test: return the result of ``check_goal`` for
     ``state`` against the goal stored in ``self._goal``.
     """
     goal = self._goal
     return check_goal(state, goal)
Example #5
0
# Select the problem instance and read its goal and objects.
env.fix_problem_index(problem_index)
problem = env.problems[problem_index]
goal = problem.goal
prob_objects = frozenset(problem.objects)

obs, _ = env.reset()
# Sorted list of all ground actions (including ones not currently valid).
A = sorted(env.action_space.all_ground_literals(obs, valid_only=False))

print(' calculating list of states...')
reach = get_all_reachable(obs, A, env)
S = sorted(reach)
print('Number of states:', len(S))

print('done')
# State -> position in the value/policy arrays; G_i holds the goal positions.
V_i = {s: i for i, s in enumerate(S)}
G_i = [idx for state, idx in V_i.items() if check_goal(state, goal)]

print('obtaining optimal policy')
# Successor table; the lookups performed by `vi` use the (state, action) keys.
succ_states = {s: {} for s in reach}
for s in reach:
    for a in A:
        succ_states[(s, a)] = reach[s][a]
V, pi = vi(S, succ_states, A, V_i, G_i, goal, env, args.gamma, args.epsilon)


def pi_func(s):
    """Return the action stored in `pi` for state `s`."""
    return pi[V_i[s]]


n_episodes = 1000

plot = False
if args.plot_stats:
    print('running episodes with optimal policy')
    steps1 = []
Example #6
0
def get_unexpanded_states(goal, explicit_graph, bpsg):
    """List the states of ``bpsg`` that still need to be expanded.

    A state qualifies when it is absent from ``explicit_graph``, or when its
    node is not marked ``"expanded"`` and ``check_goal`` reports it is not a
    goal. The result follows the iteration order of ``bpsg.keys()``.
    """
    return [
        state for state in bpsg.keys()
        if state not in explicit_graph
        or not (explicit_graph[state]["expanded"] or check_goal(state, goal))
    ]
Example #7
0
 def C(s, a):
     """Step cost: 0 when ``s`` is a goal state, 1 otherwise (``a`` unused)."""
     if check_goal(s, goal):
         return 0
     return 1
Example #8
0
def ilao(s0, h_v, goal, A, gamma, env, epsilon=1e-3, succs_cache=None):
    """Search from ``s0`` by interleaving expansion/backups with a
    convergence test (presumably Improved LAO* -- confirm with the authors).

    The inner loop traverses the current partial solution graph with
    ``dfs``; each visited non-goal state is expanded if needed and receives
    one Bellman backup, after which the partial solution is rebuilt. Once no
    unexpanded state remains, ``value_iteration`` is run in
    convergence-test mode on the non-goal states of the partial solution.
    The outer loop ends when that test reports no change, convergence, and
    no unexpanded states.

    Args:
        s0: Initial state.
        h_v: Callable giving the initial value of a state.
        goal: Goal description forwarded to ``check_goal``.
        A: Iterable of actions.
        gamma: Value forwarded to the backup and value-iteration helpers.
        env: Environment forwarded to ``expand_state``.
        epsilon: Threshold forwarded to ``value_iteration``.
        succs_cache: Optional successor cache forwarded to ``expand_state``;
            a fresh empty dict is used when ``None``.

    Returns:
        ``(explicit_graph, bpsg, n_updates)``: the explicit graph (with
        every state of the final partial solution marked ``'solved'``), the
        final partial solution graph, and the accumulated update count.
    """
    bpsg = {s0: {"Adj": []}}
    explicit_graph = {}
    # Identity test: only a missing cache is replaced, never a supplied one.
    succs_cache = {} if succs_cache is None else succs_cache

    explicit_graph[s0] = {
        "value": h_v(s0),
        "solved": False,
        "expanded": False,
        "pi": None,
        "Q_v": {a: h_v(s0)
                for a in A},
        "Adj": []
    }

    def C(s, a):
        # Unit cost everywhere except in goal states.
        return 0 if check_goal(s, goal) else 1

    i = 1
    unexpanded = get_unexpanded_states(goal, explicit_graph, bpsg)
    n_updates = 0
    explicit_graph_cur_size = 1
    while True:
        while unexpanded:
            print("Iteration", i)
            print(len(unexpanded), "unexpanded states")

            # Per-pass backup counter, incremented by `visit` below.
            n_updates_ = 0

            def visit(s, i, d, low):
                nonlocal explicit_graph, A, goal, n_updates_
                is_goal = check_goal(s, goal)
                if not is_goal and not explicit_graph[s]['expanded']:
                    explicit_graph = expand_state(s,
                                                  h_v,
                                                  env,
                                                  explicit_graph,
                                                  goal,
                                                  A,
                                                  succs_cache=succs_cache)
                if not is_goal:
                    # run bellman backup
                    explicit_graph = backup_bellman(explicit_graph, A, s, goal,
                                                    gamma, C)
                    n_updates_ += 1

            dfs(bpsg, on_visit=visit)

            # Expansion only ever adds states to the explicit graph.
            assert len(explicit_graph) >= explicit_graph_cur_size

            explicit_graph_cur_size = len(explicit_graph)
            print("explicit graph size:", explicit_graph_cur_size)
            print(f"Finished value iteration in {n_updates_} updates")
            n_updates += n_updates_
            bpsg = update_partial_solution(s0, bpsg, explicit_graph)

            unexpanded = get_unexpanded_states(goal, explicit_graph, bpsg)
            i += 1

        bpsg_states = [s_ for s_ in bpsg if not check_goal(s_, goal)]
        print(f"Will start convergence test for bpsg with {len(bpsg)} states")
        explicit_graph, converged, changed, n_updates_ = value_iteration(
            explicit_graph,
            bpsg,
            A,
            bpsg_states,
            goal,
            gamma,
            C,
            epsilon=epsilon,
            convergence_test=True)
        n_updates += n_updates_
        print(f"Finished convergence test in {n_updates_} updates")

        bpsg = update_partial_solution(s0, bpsg, explicit_graph)

        unexpanded = get_unexpanded_states(goal, explicit_graph, bpsg)

        # When the test reports a change, go around again before checking
        # the exit condition.
        if changed:
            continue

        if converged and not unexpanded:
            break

    for s_ in bpsg:
        explicit_graph[s_]['solved'] = True
    return explicit_graph, bpsg, n_updates
Example #9
0
def lao(s0, h_v, goal, A, gamma, env, epsilon=1e-3, succs_cache=None):
    """Search from ``s0`` by alternating batch expansion with value
    iteration (presumably LAO* -- confirm with the authors).

    While the partial solution graph has unexpanded states, all of them are
    expanded, and ``value_iteration`` is run on the set ``Z`` made of those
    states plus whatever ``find_ancestors(..., best=True)`` returns for
    them; the partial solution is then rebuilt. Once nothing is left to
    expand, ``value_iteration`` is run in convergence-test mode on the
    non-goal states of the partial solution. The outer loop ends when that
    test reports convergence and no unexpanded state remains.

    Args:
        s0: Initial state.
        h_v: Callable giving the initial value of a state.
        goal: Goal description forwarded to ``check_goal``.
        A: Iterable of actions.
        gamma: Value forwarded to ``value_iteration``.
        env: Environment forwarded to ``expand_state``.
        epsilon: Threshold forwarded to ``value_iteration``.
        succs_cache: Optional successor cache forwarded unchanged to
            ``expand_state``; the default ``None`` matches the previous
            behaviour of not passing a cache.

    Returns:
        ``(explicit_graph, bpsg, n_updates)``: the explicit graph, the final
        partial solution graph, and the accumulated update count.
    """
    bpsg = {s0: {"Adj": []}}
    explicit_graph = {}

    explicit_graph[s0] = {
        "value": h_v(s0),
        "solved": False,
        "expanded": False,
        "pi": None,
        "Q_v": {a: h_v(s0)
                for a in A},
        "Adj": []
    }

    def C(s, a):
        # Unit cost everywhere except in goal states.
        return 0 if check_goal(s, goal) else 1

    i = 1

    unexpanded = get_unexpanded_states(goal, explicit_graph, bpsg)
    n_updates = 0
    explicit_graph_cur_size = 1
    while True:
        while unexpanded:
            print("Iteration", i)
            print("Will expand", len(unexpanded), "states")
            # Z: the states handed to value iteration after this batch.
            Z = set()
            for s in unexpanded:
                explicit_graph = expand_state(s, h_v, env, explicit_graph,
                                              goal, A,
                                              succs_cache=succs_cache)
                Z.add(s)
                Z.update(find_ancestors(s, explicit_graph, best=True))

            # Expansion only ever adds states to the explicit graph.
            assert len(explicit_graph) >= explicit_graph_cur_size
            explicit_graph_cur_size = len(explicit_graph)
            print("explicit graph size:", explicit_graph_cur_size)
            print("Z size:", len(Z))
            explicit_graph, _, __, n_updates_ = value_iteration(
                explicit_graph, bpsg, A, Z, goal, gamma, C, epsilon=epsilon)
            print(f"Finished value iteration in {n_updates_} updates")
            n_updates += n_updates_
            bpsg = update_partial_solution(s0, bpsg, explicit_graph)
            unexpanded = get_unexpanded_states(goal, explicit_graph, bpsg)
            i += 1

        bpsg_states = [s_ for s_ in bpsg if not check_goal(s_, goal)]
        print(f"Will start convergence test for bpsg with {len(bpsg)} states")
        explicit_graph, converged, changed, n_updates_ = value_iteration(
            explicit_graph,
            bpsg,
            A,
            bpsg_states,
            goal,
            gamma,
            C,
            epsilon=epsilon,
            convergence_test=True)
        print(f"Finished convergence test in {n_updates_} updates")
        n_updates += n_updates_

        bpsg = update_partial_solution(s0, bpsg, explicit_graph)
        unexpanded = get_unexpanded_states(goal, explicit_graph, bpsg)

        if converged and not unexpanded:
            break
    return explicit_graph, bpsg, n_updates