Code example #1
# Imports assume simple_rl's module layout.
import random

from simple_rl.mdp.MDPClass import MDP


def _make_mini_mdp(pre_abs_state, post_abs_state, state_abstr, mdp):
    '''
    Args:
        pre_abs_state (simple_rl.State)
        post_abs_state (simple_rl.State)
        state_abstr (StateAbstraction)
        mdp (simple_rl.MDP)

    Returns:
        (simple_rl.MDP)
    '''

    # Get init and terminal lower level states.
    ground_init_states = state_abstr.get_lower_states_in_abs_state(
        pre_abs_state)
    ground_reward_states = state_abstr.get_lower_states_in_abs_state(
        post_abs_state)
    rand_init_g_state = random.choice(ground_init_states)

    # R and T for the option's mini MDP.
    def _directed_option_reward_lambda(s, a, s_prime):
        # TODO: might need to sample here?
        # Temporarily mark s terminal if it lies outside the initiation cluster,
        # recompute the successor from the ground MDP (the passed-in s_prime is
        # ignored), then restore the terminal flag.
        original = s.is_terminal()
        s.set_terminal(s not in ground_init_states)
        s_prime = mdp.transition_func(s, a)
        s.set_terminal(original)

        # Non-zero reward iff the action transitions into the target abstract state.
        return int(s_prime in ground_reward_states
                   and s not in ground_reward_states)

    def new_trans_func(s, a):
        # Same terminal toggle: s is treated as terminal unless it lies in the
        # initiation cluster; the successor comes from the ground MDP.
        original = s.is_terminal()
        s.set_terminal(s not in ground_init_states)
        s_prime = mdp.transition_func(s, a)
        s.set_terminal(original)
        return s_prime

    mini_mdp = MDP(actions=mdp.get_actions(),
                   init_state=rand_init_g_state,
                   transition_func=new_trans_func,
                   reward_func=_directed_option_reward_lambda)

    return mini_mdp
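Both examples revolve around solving such a mini MDP to obtain a directed option policy; code example #2 below does this through a helper, _make_mini_mdp_option_policy, that is not included in these excerpts. Below is a minimal sketch of what that helper might look like, using simple_rl's ValueIteration (the specific solver parameters are an illustrative assumption, not the original implementation):

from simple_rl.planning import ValueIteration

def _make_mini_mdp_option_policy(mini_mdp):
    # Sketch: solve the option's mini MDP with value iteration and return its
    # greedy policy together with the planner (the real helper may differ).
    mini_mdp_vi = ValueIteration(mini_mdp, delta=0.005, max_iterations=1000, sample_rate=30)
    mini_mdp_vi.run_vi()
    # ValueIteration.policy(state) returns the greedy action for that state.
    return mini_mdp_vi.policy, mini_mdp_vi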
Code example #2
# Imports assume simple_rl's module layout.
import random
from collections import defaultdict

from simple_rl.mdp.MDPClass import MDP
from simple_rl.tasks import GridWorldMDP
from simple_rl.abstraction.action_abs.OptionClass import Option


def _prune_non_directed_options(options, state_pairs, state_abstr, mdp_distr):
    '''
    Args:
        options (list of Option)
        state_pairs (list)
        state_abstr (StateAbstraction)
        mdp_distr (MDPDistribution)

    Returns:
        (set of Option)

    Summary:
        Removes redundant options. That is, if o_1 goes from s_A1 to s_A2, and
        o_2 goes from s_A1 *through s_A2 to s_A3, then we get rid of o_2.
    '''
    good_options = set([])
    bad_options = set([])
    transition_func = mdp_distr.get_all_mdps()[0].get_transition_func()

    # For each option we created, we'll check overlap.
    for i, o in enumerate(options):
        print "\t  Option", i + 1, "of", len(options)
        pre_abs_state, post_abs_state = state_pairs[i]

        # Get init and terminal lower level states.
        ground_init_states = state_abstr.get_lower_states_in_abs_state(
            pre_abs_state)
        ground_term_states = state_abstr.get_lower_states_in_abs_state(
            post_abs_state)
        rand_init_g_state = random.choice(ground_init_states)

        # R and T for Option Mini MDP.
        def _directed_option_reward_lambda(s, a):
            s_prime = transition_func(s, a)
            return int(s_prime in ground_term_states
                       and s not in ground_term_states)

        def new_trans_func(s, a):
            # Temporarily mark s terminal if it is already in the target cluster,
            # query the ground transition function, then restore the flag.
            original = s.is_terminal()
            s.set_terminal(s in ground_term_states)
            s_prime = transition_func(s, a)
            s.set_terminal(original)
            return s_prime

        if pre_abs_state == post_abs_state:
            # Self looping option.
            mini_mdp_init_states = defaultdict(list)

            # Self loop. Make an option per goal in the cluster.
            goal_mdps = []
            goal_state_action_pairs = defaultdict(list)
            for mdp in mdp_distr.get_all_mdps():
                add = False

                # Check if there is a goal for this MDP in one of the ground states.
                for s_g in ground_term_states:
                    for a in mdp.get_actions():
                        if (mdp.get_reward_func()(s_g, a) > 0.0
                                and a not in goal_state_action_pairs[s_g]):
                            goal_state_action_pairs[s_g].append(a)
                            if isinstance(mdp, GridWorldMDP):
                                goals = tuple(mdp.get_goal_locs())
                            else:
                                goals = tuple(s_g)
                            mini_mdp_init_states[goals].append(s_g)
                            add = True

                if add:
                    goal_mdps.append(mdp)

            # For each goal.
            for goal_mdp in goal_mdps:

                def goal_new_trans_func(s, a):
                    # Temporarily mark s terminal if it lies outside the cluster
                    # (or was terminal already), then restore the flag afterwards.
                    original = s.is_terminal()
                    s.set_terminal(s not in ground_term_states or original)
                    s_prime = goal_mdp.get_transition_func()(s, a)
                    s.set_terminal(original)
                    return s_prime

                if isinstance(goal_mdp, GridWorldMDP):
                    cluster_init_state = random.choice(
                        mini_mdp_init_states[tuple(goal_mdp.get_goal_locs())])
                else:
                    cluster_init_state = random.choice(ground_init_states)

                # Make a new MDP.
                mini_mdp = MDP(actions=goal_mdp.get_actions(),
                               init_state=cluster_init_state,
                               transition_func=goal_new_trans_func,
                               reward_func=goal_mdp.get_reward_func())

                o_policy, mini_mdp_vi = _make_mini_mdp_option_policy(mini_mdp)

                # Make new option.
                new_option = Option(o.init_predicate, o.term_predicate,
                                    o_policy)
                new_option.set_name(str(ground_init_states[0]) + "-sl")
                good_options.add(new_option)

            continue
        else:
            # This is a non-self looping option.
            mini_mdp = MDP(actions=mdp_distr.get_actions(),
                           init_state=rand_init_g_state,
                           transition_func=new_trans_func,
                           reward_func=_directed_option_reward_lambda)

            o_policy, mini_mdp_vi = _make_mini_mdp_option_policy(mini_mdp)
            # Compute overlap w.r.t. plans from each state.
            for init_g_state in ground_init_states:
                # Prune overlapping ones.
                plan, state_seq = mini_mdp_vi.plan(init_g_state)
                opt_name = str(ground_init_states[0]) + "-" + str(
                    ground_term_states[0])
                o.set_name(opt_name)
                options[i] = o

                if not _check_overlap(o, state_seq, options, bad_options):
                    # Give the option the new directed policy and name.
                    o.set_policy(o_policy)
                    good_options.add(o)
                    break
                else:
                    # The option overlaps, don't include it.
                    bad_options.add(o)

    return good_options
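A minimal driver sketch tying the two examples together. This is illustrative only: demo_prune is a made-up name, the two GridWorldMDP variants are hypothetical, and options, state_pairs, and state_abstr are assumed to come from abstraction-building code not shown here, with options[k] being the candidate option for the abstract state pair state_pairs[k].

from simple_rl.mdp.MDPDistributionClass import MDPDistribution
from simple_rl.tasks import GridWorldMDP

def demo_prune(options, state_pairs, state_abstr):
    # Illustrative distribution: two grid worlds that differ only in goal location.
    mdp_distr = MDPDistribution({
        GridWorldMDP(width=5, height=5, goal_locs=[(5, 5)]): 0.5,
        GridWorldMDP(width=5, height=5, goal_locs=[(1, 5)]): 0.5,
    })

    # Prune redundant candidates, keeping only directed options.
    directed_options = _prune_non_directed_options(
        options, state_pairs, state_abstr, mdp_distr)
    print(len(directed_options), "directed options kept out of", len(options))
    return directed_options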