def make_fixed_rand_options(mdp, state_abstr):
    '''
    Summary:
        Builds one option per unordered pair of distinct abstract states.
        Each option initiates in the ground states of the first abstract
        state of the pair, terminates in the ground states of the second,
        and follows a placeholder policy that picks a random action.

    Args:
        mdp (simple_rl.MDP)
        state_abstr (simple_rl.StateAbstraction)

    Returns:
        (list): Contains Option instances.
    '''
    abs_states = state_abstr.get_abs_states()

    # Every option shares the same placeholder policy: a random action
    # drawn from whatever the MDP offers in the given state.
    placeholder_policy = lambda s : random.choice(mdp.get_actions(s))

    options = []
    seen_pairs = set()

    # For each s_{phi,1} s_{phi,2} pair.
    for s_a in abs_states:
        for s_a_prime in abs_states:
            if s_a == s_a_prime:
                continue
            # Only one direction is kept per pair: once (a, b) has an
            # option, neither (a, b) nor (b, a) is created again.
            if (s_a, s_a_prime) in seen_pairs or (s_a_prime, s_a) in seen_pairs:
                continue

            # Make an option to transition between the two states.
            init_predicate = InListPredicate(ls=state_abstr.get_ground_states_in_abs_state(s_a))
            term_predicate = InListPredicate(ls=state_abstr.get_ground_states_in_abs_state(s_a_prime))

            options.append(Option(init_predicate=init_predicate,
                                  term_predicate=term_predicate,
                                  policy=placeholder_policy))
            seen_pairs.add((s_a, s_a_prime))

    return options
def find_eigenoptions(mdp, num_options=4, init_everywhere=False):
    '''
    Summary:
        Builds options from the eigenvectors of the normalized graph
        Laplacian D^{-1/2} (D - A) D^{-1/2}, where A is the matrix returned
        by get_transition_matrix (with unit self-loops removed) and D is the
        diagonal matrix of its column sums. For the i-th eigenvector, the
        states whose entries are within a small tolerance of the maximum
        form the initiation set and those within the tolerance of the
        minimum form the goal (termination) set.

    Args:
        mdp: passed to get_transition_matrix and make_option_policy.
        num_options (int): number of eigenvectors (columns 0..num_options-1
            as returned by np.linalg.eigh) to turn into options.
        init_everywhere (bool): if True, every option can be initiated in
            any state; otherwise only in its initiation set.

    Returns:
        (list): Contains Option instances, one per eigenvector used.

    Notes:
        Assumes the state space is strongly connected; a state with zero
        degree makes 1 / sqrt(degree) blow up.
    '''
    delta = 0.001  # threshold for float point error

    # TODO: assume that the state-space is strongly connected.

    # Compute laplacian. A is modified in place.
    A, _, id_to_state = get_transition_matrix(mdp)
    for n in range(A.shape[0]):
        if A[n][n] == 1:
            A[n][n] = 0  # Prune self-loops for the analysis
    degrees = np.sum(A, axis=0)
    T = np.diag(degrees)
    Tngsqrt = np.diag(1.0 / np.sqrt(degrees))
    L = T - A
    normL = np.matmul(np.matmul(Tngsqrt, L), Tngsqrt)
    _, eigenvecs = np.linalg.eigh(normL)

    eigenoptions = []

    # NOTE(review): index 0 (the smallest eigenvalue) is included here even
    # though it is regarded as not useful; confirm whether the range should
    # start at 1 instead.
    for i in range(num_options):
        eigenvec = eigenvecs[:, i]

        # Both extremes are taken from the SAME eigenvector, so that each
        # option gets its own goal set.
        max_ids = np.argwhere(eigenvec >= np.amax(eigenvec) - delta).flatten()
        min_ids = np.argwhere(eigenvec <= np.amin(eigenvec) + delta).flatten()

        # Make init/goal sets.
        init_set = [id_to_state[s] for s in max_ids]
        goal_set = [id_to_state[s] for s in min_ids]

        # Define predicates.
        if init_everywhere:
            # Initiate everywhere.
            init_predicate = Predicate(lambda x:True)
        else:
            # Initiate only in the states that maximize the eigenvector.
            init_predicate = InListPredicate(ls=init_set)
        term_predicate = InListPredicate(ls=goal_set)

        eigen_o = Option(init_predicate=init_predicate,
                         term_predicate=term_predicate,
                         policy=make_option_policy(mdp, id_to_state.values(), goal_set))

        eigenoptions.append(eigen_o)

    return eigenoptions
def make_goal_based_options(mdp_distr):
    '''
    Summary:
        Creates one option per MDP in the distribution. Every option can be
        initiated anywhere and terminates in any state that is terminal in
        at least one of the MDPs; its policy comes from
        _make_mini_mdp_option_policy for that MDP.

    Args:
        mdp_distr (MDPDistribution)

    Returns:
        (set): Contains Option instances.
    '''
    # Pool the terminal states found across all MDPs.
    terminal_states = set()
    for each_mdp in mdp_distr.get_all_mdps():
        planner = ValueIteration(each_mdp)
        terminal_states.update(
            state for state in planner.get_states() if state.is_terminal())

    # One option per MDP, all sharing the pooled termination set.
    goal_options = set()
    for each_mdp in mdp_distr.get_all_mdps():
        goal_options.add(
            Option(init_predicate=Predicate(func=lambda x: True),
                   term_predicate=InListPredicate(ls=terminal_states),
                   policy=_make_mini_mdp_option_policy(each_mdp),
                   term_prob=0.0))

    return goal_options
def compute_omega_given_m_phi(mdp, state_abstr):
    '''
    Summary:
        Creates a random-policy option for each unordered pair of distinct
        abstract states, prunes the result with
        ah._prune_redundant_options, and wraps what is left in an
        ActionAbstraction that uses on_failure="primitives".

    Args:
        mdp (simple_rl.MDP)
        state_abstr (simple_rl.abstraction.StateAbstraction)

    Returns:
        omega (simple_rl.abstraction.ActionAbstraction)
    '''
    abstract_states = state_abstr.get_abs_states()
    ground_start = mdp.get_init_state()  # Fetched but not used below.

    # Placeholder behaviour shared by all options: a random available action.
    random_action_policy = lambda s: random.choice(mdp.get_actions(s))

    candidate_options = []
    linked_pairs = {}

    for src in abstract_states:
        for dst in abstract_states:
            if src == dst:
                continue
            # Skip the pair if either direction already has an option.
            if (src, dst) in linked_pairs or (dst, src) in linked_pairs:
                continue

            # The option starts in src's ground states and ends in dst's.
            start_states = state_abstr.get_ground_states_in_abs_state(src)
            init_predicate = InListPredicate(ls=start_states)
            end_states = state_abstr.get_ground_states_in_abs_state(dst)
            term_predicate = InListPredicate(ls=end_states)

            candidate_options.append(Option(init_predicate=init_predicate,
                                            term_predicate=term_predicate,
                                            policy=random_action_policy))
            linked_pairs[(src, dst)] = 1

    # Prune.
    kept_options = ah._prune_redundant_options(candidate_options,
                                               linked_pairs.keys(),
                                               state_abstr, mdp)

    return ActionAbstraction(options=kept_options, on_failure="primitives")
def make_point_options(mdp, pairs, policy='vi'):
    '''
    Summary:
        Creates one option per (init, term) pair. The option initiates in
        the pair's init states and terminates in its term states.

    Args:
        mdp
        pairs: a list of pairs. Each pair holds an init set and a term set;
            an entry that is not a list is wrapped into a one-element list.
        policy (str): 'vi' builds the option policy with
            _make_mini_mdp_option_policy, 'dqn' with _make_dqn_option_policy
            aimed at the first term state. Anything else trips an assert.

    Returns:
        (set): Contains Option instances.
    '''
    point_options = set()

    for pair in pairs:
        init_states, term_states = pair[0], pair[1]

        # Bare (non-list) entries are promoted to single-element lists.
        init_states = init_states if type(init_states) is list else [init_states]
        term_states = term_states if type(term_states) is list else [term_states]

        init_predicate = InListPredicate(ls=init_states)
        term_predicate = InListPredicate(ls=term_states)

        if policy == 'vi':
            option_policy = _make_mini_mdp_option_policy(mdp, n_iters=100)
        elif policy == 'dqn':
            option_policy = _make_dqn_option_policy(mdp, term_states[0])
        else:
            assert (False)

        point_options.add(Option(init_predicate=init_predicate,
                                 term_predicate=term_predicate,
                                 policy=option_policy,
                                 term_prob=0.0))

    return point_options
def make_subgoal_options(mdp,
                         goal_list,
                         init_space=None,
                         vectors=None,
                         n_trajs=100,
                         n_steps=100,
                         classifier='list',
                         policy='vi'):
    '''
    Summary:
        Creates one option per entry of goal_list. All options share an
        initiation predicate built from init_space. Each option's
        termination predicate is built, with true_if_in=False, from a copy
        of init_space with that entry's goals removed.

    Args:
        mdp
        goal_list: set of lists; each inner collection holds the goals of
            one option.
        init_space: list of states.
        vectors: indexed by the option's position in goal_list; only read
            when policy == 'dqn'.
        n_trajs (int): forwarded to _make_dqn_option_policy.
        n_steps (int): forwarded to _make_dqn_option_policy.
        classifier (str): 'list' uses InListPredicate, 'svc' uses
            ClassifierPredicate. Anything else prints an error and asserts.
        policy (str): 'vi' plans on an IntrinsicMDP that gives intrinsic
            reward 1 to each goal (keyed by hash); 'dqn' uses
            _make_dqn_option_policy. Anything else prints an error and
            asserts.

    Returns:
        (set): Contains Option instances.
    '''
    if classifier == 'list':
        init_predicate = InListPredicate(ls=init_space)
    elif classifier == 'svc':
        init_predicate = ClassifierPredicate(init_space)
    else:
        print('Error: unknown predicate for init condition:', classifier)
        assert (False)

    subgoal_options = set([])

    for idx, goals in enumerate(goal_list):
        # Start from the whole initiation space and carve out the goals.
        # NOTE(review): with true_if_in=False the predicate presumably fires
        # for the goals and for anything outside init_space -- confirm
        # against the predicate classes.
        remaining = copy(init_space)
        for goal in goals:
            if goal in remaining:
                remaining.remove(goal)

        if classifier == 'list':
            term_predicate = InListPredicate(ls=remaining, true_if_in=False)
        elif classifier == 'svc':
            term_predicate = ClassifierPredicate(remaining, true_if_in=False)
        else:
            print('Error: unknown predicate for init condition:', classifier)
            assert (False)

        if policy == 'vi':
            # Intrinsic reward of 1 on every goal of this option.
            intrinsic_reward = {hash(goal): 1 for goal in goals}
            intrinsic_mdp = IntrinsicMDP(intrinsic_reward=intrinsic_reward,
                                         mdp=mdp)
            option_policy = _make_mini_mdp_option_policy(intrinsic_mdp,
                                                         n_iters=100)
        elif policy == 'dqn':
            option_policy = _make_dqn_option_policy(mdp,
                                                    vectors[idx],
                                                    n_trajs=n_trajs,
                                                    n_steps=n_steps)
        else:
            print('Error: unknown policy for options:', policy)
            assert (False)

        subgoal_options.add(Option(init_predicate=init_predicate,
                                   term_predicate=term_predicate,
                                   policy=option_policy,
                                   term_prob=0.0))

    return subgoal_options
def find_betweenness_options(mdp, t=0.1, init_everywhere=False):
    '''
    Summary:
        Builds options whose goals are states with high betweenness
        centrality in the graph derived from the MDP's transition matrix.
          1. A node v is a candidate subgoal if, for some source s != v, its
             subset betweenness (paths from s to every node), divided by
             (N - 2), exceeds t.
          2. For each candidate g, the initiation set holds the nodes whose
             subset betweenness (paths from every node to g), divided by
             (N - 2), exceeds t; the support score of g is the sum of those
             betweenness values.
          3. Within each connected component of the subgraph induced by the
             candidates, only the candidate with the highest support score
             is kept.

    Args:
        mdp: passed to get_transition_matrix and make_option_policy.
        t (float): threshold on the (N - 2)-normalized betweenness.
        init_everywhere (bool): if True, every option can be initiated in
            any state; otherwise only in its initiation set.

    Returns:
        (list): Contains Option instances, one per retained subgoal. Empty
            when the graph has two or fewer nodes.
    '''
    T, _, id_to_state = get_transition_matrix(mdp)

    # networkx 3.0 removed from_numpy_matrix; use from_numpy_array there.
    build_graph = getattr(nx, "from_numpy_matrix", None) or nx.from_numpy_array
    G = build_graph(T)
    N = G.number_of_nodes()

    # With two or fewer nodes no node can lie between two others, and the
    # (N - 2) normalizer below would be zero or negative.
    if N <= 2:
        return []
    norm = N - 2

    #########################
    ## 1. Enumerate all candidate subgoals
    #########################
    subgoal_set = []
    for s in G.nodes():
        csv = nx.betweenness_centrality_subset(G, sources=[s], targets=G.nodes())
        for v, centrality in csv.items():
            # Compare node ids by value, not identity.
            if s != v and centrality / norm > t and v not in subgoal_set:
                subgoal_set.append(v)

    #########################
    ## 2. Generate an initiation set for each subgoal
    #########################
    initiation_sets = {}
    support_scores = {}
    for g in subgoal_set:
        csg = nx.betweenness_centrality_subset(G, sources=G.nodes(), targets=[g])
        supporters = [s for s in G.nodes() if csg[s] / norm > t]
        initiation_sets[g] = supporters
        support_scores[g] = sum(csg[s] for s in supporters)

    #########################
    ## 3. Filter subgoals according to their supports
    #########################
    subgoal_graph = G.subgraph(subgoal_set)

    # TODO: connected components are used instead of SCCs
    # Keep the best-supported subgoal of each component (first one on ties).
    filtered_subgoals = [max(component, key=support_scores.get)
                         for component in nx.connected_components(subgoal_graph)]

    options = []
    for g in filtered_subgoals:
        init_set = [id_to_state[s] for s in initiation_sets[g]]
        goal_set = [id_to_state[g]]

        # Define predicates.
        if init_everywhere:
            # Initiate everywhere.
            init_predicate = Predicate(lambda x:True)
        else:
            # Initiate only inside the subgoal's initiation set.
            init_predicate = InListPredicate(ls=init_set)
        term_predicate = InListPredicate(ls=goal_set)

        between_o = Option(init_predicate=init_predicate,
                           term_predicate=term_predicate,
                           policy=make_option_policy(mdp, id_to_state.values(), goal_set))

        options.append(between_o)

    return options