def _make_mini_mdp(pre_abs_state, post_abs_state, state_abstr, mdp):
    '''
    Build a "mini" MDP for learning a directed option policy that travels
    from one abstract state's cluster to another's.

    Args:
        pre_abs_state (simple_rl.State): Abstract state the option initiates from.
        post_abs_state (simple_rl.State): Abstract state the option should reach.
        state_abstr (StateAbstraction): Maps each abstract state to its ground states.
        mdp (simple_rl.MDP): The underlying ground MDP.

    Returns:
        (simple_rl.MDP): An MDP whose reward is 1 iff an action transitions
            into post_abs_state's ground cluster from outside it, and whose
            transition function treats ground states outside pre_abs_state's
            cluster as terminal.
    '''
    # Get init and terminal lower level (ground) states.
    ground_init_states = state_abstr.get_lower_states_in_abs_state(pre_abs_state)
    ground_reward_states = state_abstr.get_lower_states_in_abs_state(post_abs_state)
    rand_init_g_state = random.choice(ground_init_states)

    # R and T for the option's mini MDP.
    def _directed_option_reward_lambda(s, a, s_prime):
        '''
        Reward: 1 iff taking a from s enters the goal cluster from outside it.

        NOTE(review): the passed-in s_prime is intentionally NOT used; a next
        state is re-sampled from mdp.transition_func with the terminal flag
        temporarily masked (the original carried "TODO: might need to sample
        here?" and shadowed the parameter). In a stochastic MDP the sampled
        state may differ from the s_prime actually reached -- confirm intent.
        '''
        # Temporarily mark states outside the initiation cluster as terminal
        # so the ground transition function confines rollouts, then restore.
        was_terminal = s.is_terminal()
        s.set_terminal(s not in ground_init_states)
        sampled_next = mdp.transition_func(s, a)
        s.set_terminal(was_terminal)
        # Non-zero reward iff the action transitions to a new abstract state.
        return int(sampled_next in ground_reward_states and s not in ground_reward_states)

    def new_trans_func(s, a):
        '''Ground transitions, with states outside the initiation cluster masked as terminal.'''
        was_terminal = s.is_terminal()
        s.set_terminal(s not in ground_init_states)
        s_prime = mdp.transition_func(s, a)
        s.set_terminal(was_terminal)
        return s_prime

    mini_mdp = MDP(actions=mdp.get_actions(),
                   init_state=rand_init_g_state,
                   transition_func=new_trans_func,
                   reward_func=_directed_option_reward_lambda)
    return mini_mdp
def _prune_non_directed_options(options, state_pairs, state_abstr, mdp_distr):
    '''
    Args:
        options (list of Option): Candidate options, one per abstract state pair.
        state_pairs (list): (pre_abs_state, post_abs_state) pair for each option.
        state_abstr (StateAbstraction)
        mdp_distr (MDPDistribution)

    Returns:
        (set of Options): The surviving (non-redundant) options.

    Summary:
        Removes redundant options. That is, if o_1 goes from s_A1 to s_A2, and
        o_2 goes from s_A1 *through* s_A2 to s_A3, then we get rid of o_2.
        Self-looping pairs (pre == post) instead spawn one goal-directed
        option per goal found inside the cluster.
    '''
    good_options = set([])
    bad_options = set([])
    # Transition dynamics are assumed shared across the distribution; take the first MDP's.
    transition_func = mdp_distr.get_all_mdps()[0].get_transition_func()

    # For each option we created, we'll check overlap.
    for i, o in enumerate(options):
        print "\t Option", i + 1, "of", len(options)
        pre_abs_state, post_abs_state = state_pairs[i]

        # Get init and terminal lower level (ground) states for this pair.
        ground_init_states = state_abstr.get_lower_states_in_abs_state(pre_abs_state)
        ground_term_states = state_abstr.get_lower_states_in_abs_state(post_abs_state)
        rand_init_g_state = random.choice(ground_init_states)

        # R and T for the option's mini MDP (used only in the non-self-loop branch below).
        def _directed_option_reward_lambda(s, a):
            # Reward 1 iff the action transitions into the goal cluster from outside it.
            s_prime = transition_func(s, a)
            return int(s_prime in ground_term_states and not s in ground_term_states)

        def new_trans_func(s, a):
            # Temporarily mark goal-cluster states terminal, sample, then restore the flag.
            original = s.is_terminal()
            s.set_terminal(s in ground_term_states)
            s_prime = transition_func(s, a)
            s.set_terminal(original)
            return s_prime

        if pre_abs_state == post_abs_state:
            # Self looping option.
            mini_mdp_init_states = defaultdict(list)

            # Self loop. Make an option per goal in the cluster.
            goal_mdps = []
            goal_state_action_pairs = defaultdict(list)
            # NOTE(review): this inner 'i' shadows the outer option index; it is
            # harmless only because this branch ends in 'continue' before the
            # else-branch code that reads 'options[i]'.
            for i, mdp in enumerate(mdp_distr.get_all_mdps()):
                add = False

                # Check if there is a goal for this MDP in one of the ground states.
                for s_g in ground_term_states:
                    for a in mdp.get_actions():
                        # A positive reward at (s_g, a) marks s_g as a goal for this MDP.
                        if mdp.get_reward_func()(s_g, a) > 0.0 and a not in goal_state_action_pairs[s_g]:
                            goal_state_action_pairs[s_g].append(a)
                            if isinstance(mdp, GridWorldMDP):
                                # Key init states by the MDP's goal locations.
                                goals = tuple(mdp.get_goal_locs())
                            else:
                                # NOTE(review): tuple(s_g) iterates the state object
                                # itself to form the key -- confirm states are iterable.
                                goals = tuple(s_g)
                            mini_mdp_init_states[goals].append(s_g)
                            add = True

                if add:
                    goal_mdps.append(mdp)

            # For each goal, build a goal-directed mini MDP and option.
            for goal_mdp in goal_mdps:

                # NOTE(review): closes over the loop variable goal_mdp; safe here
                # because the closure is consumed within this same iteration.
                def goal_new_trans_func(s, a):
                    original = s.is_terminal()
                    s.set_terminal(s not in ground_term_states or original)
                    s_prime = goal_mdp.get_transition_func()(s, a)
                    s.set_terminal(original)
                    return s_prime

                if isinstance(goal_mdp, GridWorldMDP):
                    cluster_init_state = random.choice(mini_mdp_init_states[tuple(goal_mdp.get_goal_locs())])
                else:
                    cluster_init_state = random.choice(ground_init_states)

                # Make a new MDP confined to the cluster, keeping the goal MDP's reward.
                mini_mdp = MDP(actions=goal_mdp.get_actions(),
                               init_state=cluster_init_state,
                               transition_func=goal_new_trans_func,
                               reward_func=goal_mdp.get_reward_func())

                o_policy, mini_mdp_vi = _make_mini_mdp_option_policy(mini_mdp)

                # Make new option ("-sl" = self loop).
                new_option = Option(o.init_predicate, o.term_predicate, o_policy)
                new_option.set_name(str(ground_init_states[0]) + "-sl")
                good_options.add(new_option)

            continue
        else:
            # This is a non-self looping option.
            mini_mdp = MDP(actions=mdp_distr.get_actions(),
                           init_state=rand_init_g_state,
                           transition_func=new_trans_func,
                           reward_func=_directed_option_reward_lambda)

            o_policy, mini_mdp_vi = _make_mini_mdp_option_policy(mini_mdp)

            # Compute overlap w.r.t. plans from each ground init state.
            for init_g_state in ground_init_states:
                # Prune overlapping ones.
                plan, state_seq = mini_mdp_vi.plan(init_g_state)
                opt_name = str(ground_init_states[0]) + "-" + str(ground_term_states[0])
                o.set_name(opt_name)
                options[i] = o

                if not _check_overlap(o, state_seq, options, bad_options):
                    # Give the option the new directed policy and name.
                    o.set_policy(o_policy)
                    good_options.add(o)
                    break
                else:
                    # The option overlaps, don't include it.
                    # NOTE(review): o may land in bad_options here and still be
                    # added to good_options from a later init state -- confirm.
                    bad_options.add(o)

    return good_options