Code Example #1
File: aa_helpers.py Project: david-abel/simple_rl
def make_goal_based_options(mdp_distr):
    '''
    Args:
        mdp_distr (MDPDistribution)

    Returns:
        (list): Contains Option instances.
    '''

    goal_list = set([])
    for mdp in mdp_distr.get_all_mdps():
        vi = ValueIteration(mdp)
        state_space = vi.get_states()
        for s in state_space:
            if s.is_terminal():
                goal_list.add(s)

    options = set([])
    for mdp in mdp_distr.get_all_mdps():

        init_predicate = Predicate(func=lambda x: True)
        term_predicate = InListPredicate(ls=goal_list)
        o = Option(init_predicate=init_predicate,
                    term_predicate=term_predicate,
                    policy=_make_mini_mdp_option_policy(mdp),
                    term_prob=0.0)
        options.add(o)

    return options
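A minimal usage sketch for the helper above. It is not taken from this file: the import paths and the MDPDistribution constructor call are assumptions based on common simple_rl conventions, and the two grid worlds are illustrative only; make_goal_based_options is the function defined in this example.

# Hypothetical usage sketch (import paths and MDPDistribution signature are assumptions).
from simple_rl.tasks import GridWorldMDP
from simple_rl.mdp import MDPDistribution

# Two grid worlds that differ only in their goal location.
mdp_a = GridWorldMDP(width=4, height=3, init_loc=(1, 1), goal_locs=[(4, 3)])
mdp_b = GridWorldMDP(width=4, height=3, init_loc=(1, 1), goal_locs=[(4, 1)])
mdp_distr = MDPDistribution({mdp_a: 0.5, mdp_b: 0.5})

# One goal-based option is built per component MDP of the distribution.
options = make_goal_based_options(mdp_distr)
print("Built %d goal-based options." % len(options))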
Code Example #2
def compute_avg_mdp(mdp_distr, sample_rate=5):
    '''
    Args:
        mdp_distr (defaultdict)

    Returns:
        (MDP)
    '''

    # Get normal components.
    init_state = mdp_distr.get_init_state()
    actions = mdp_distr.get_actions()
    gamma = mdp_distr.get_gamma()
    T = mdp_distr.get_all_mdps()[0].get_transition_func()

    # Compute avg reward.
    avg_rew = defaultdict(lambda : defaultdict(float))
    avg_trans_counts = defaultdict(lambda : defaultdict(lambda : defaultdict(float))) # Stores T_i(s,a,s') * Pr(M_i)
    for mdp in mdp_distr.get_mdps():
        prob_of_mdp = mdp_distr.get_prob_of_mdp(mdp)

        # Get a vi instance to compute state space.
        vi = ValueIteration(mdp, delta=0.0001, max_iterations=2000, sample_rate=sample_rate)
        iters, value = vi.run_vi()
        states = vi.get_states()

        for s in states:
            for a in actions:
                r = mdp.reward_func(s,a)

                avg_rew[s][a] += prob_of_mdp * r
            
                for repeat in xrange(sample_rate):
                    s_prime = mdp.transition_func(s,a)
                    avg_trans_counts[s][a][s_prime] += prob_of_mdp

    avg_trans_probs = defaultdict(lambda : defaultdict(lambda : defaultdict(float)))
    for s in avg_trans_counts.keys():
        for a in actions:
            for s_prime in avg_trans_counts[s][a].keys():
                avg_trans_probs[s][a][s_prime] = avg_trans_counts[s][a][s_prime] / sum(avg_trans_counts[s][a].values())

    def avg_rew_func(s,a):
        return avg_rew[s][a]

    avg_trans_func = T
    # def avg_trans_func(s,a):
    #     s_prime_index = list(np.random.multinomial(1, avg_trans_probs[s][a].values())).index(1)
    #     s_prime = avg_trans_probs[s][a].keys()[s_prime_index]
    #     s_prime.set_terminal(False)
    #     return s_prime

    avg_mdp = MDP(actions, avg_trans_func, avg_rew_func, init_state, gamma)

    return avg_mdp
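The core of compute_avg_mdp is the mixture-weighted average reward R_avg(s, a) = sum_i Pr(M_i) * R_i(s, a). Below is a self-contained toy sketch of just that averaging step on plain dictionaries, with no simple_rl dependency; the states, actions, and probabilities are made up for illustration.

from collections import defaultdict

# Toy reward functions for two MDPs over the same (state, action) pairs.
reward_funcs = {
    "M1": {("s0", "a"): 1.0, ("s0", "b"): 0.0},
    "M2": {("s0", "a"): 0.0, ("s0", "b"): 1.0},
}
# Probability of drawing each MDP from the distribution.
mdp_probs = {"M1": 0.75, "M2": 0.25}

# Mixture-weighted average reward, mirroring the avg_rew update above.
avg_rew = defaultdict(float)
for mdp_name, rewards in reward_funcs.items():
    for (s, a), r in rewards.items():
        avg_rew[(s, a)] += mdp_probs[mdp_name] * r

print(dict(avg_rew))  # {('s0', 'a'): 0.75, ('s0', 'b'): 0.25}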
Code Example #3
 def __init__(self,
              mdp,
              name="value_iter",
              delta=0.0001,
              max_iterations=500,
              sample_rate=3):
     ValueIteration.__init__(self, mdp, name, delta, max_iterations,
                             sample_rate)
     # Including for clarity. OptionsMDPValueIteration gets actions from its
     # MDP instance, and not from the self.actions variable in the Planner class.
     self.actions = None
Code Example #4
def main():
    import OptimalBeliefAgentClass

    # Setup multitask setting.
    # R ~ D : Puddle, Rock Sample
    # G ~ D : octo, four_room
    # T ~ D : grid

    mdp_class, is_goal_terminal, samples = parse_args()

    mdp_distr = make_mdp_distr(mdp_class=mdp_class,
                               is_goal_terminal=is_goal_terminal)
    mdp_distr.set_gamma(0.99)
    actions = mdp_distr.get_actions()

    # Compute average MDP.
    print "Making and solving avg MDP...",
    sys.stdout.flush()
    avg_mdp = compute_avg_mdp(mdp_distr)
    avg_mdp_vi = ValueIteration(avg_mdp,
                                delta=0.001,
                                max_iterations=1000,
                                sample_rate=5)
    iters, value = avg_mdp_vi.run_vi()
    print "done."  #, iters, value
    sys.stdout.flush()

    # Agents.
    print "Making agents...",
    sys.stdout.flush()
    mdp_distr_copy = copy.deepcopy(mdp_distr)
    opt_stoch_policy = compute_optimal_stoch_policy(mdp_distr_copy)
    opt_stoch_policy_agent = FixedPolicyAgent(opt_stoch_policy,
                                              name="$\pi_{prior}$")
    opt_belief_agent = OptimalBeliefAgentClass.OptimalBeliefAgent(
        mdp_distr, actions)
    vi_agent = FixedPolicyAgent(avg_mdp_vi.policy, name="$\pi_{avg}$")
    rand_agent = RandomAgent(actions, name="$\pi^u$")
    ql_agent = QLearningAgent(actions)
    print "done."

    agents = [vi_agent, opt_stoch_policy_agent, rand_agent, opt_belief_agent]

    # Run task.
    run_agents_multi_task(agents,
                          mdp_distr,
                          task_samples=samples,
                          episodes=1,
                          steps=100,
                          reset_at_terminal=False,
                          track_disc_reward=False,
                          cumulative_plot=True)
Code Example #5
    def __init__(self, mdp, transition_func, reward_func, observation_func, updater_type='discrete'):
        '''
        Args:
            mdp (POMDP)
            transition_func: T(s, a) --> s'
            reward_func: R(s, a) --> float
            observation_func: O(s, a) --> z
            updater_type (str)
        '''
        self.reward_func = reward_func
        self.updater_type = updater_type

        # We use the ValueIteration class to construct the transition and observation probabilities
        self.vi = ValueIteration(mdp, sample_rate=500)

        self.transition_probs = self.construct_transition_matrix(transition_func)
        self.observation_probs = self.construct_observation_matrix(observation_func, transition_func)

        if updater_type == 'discrete':
            self.updater = self.discrete_filter_updater
        elif updater_type == 'kalman':
            self.updater = self.kalman_filter_updater
        elif updater_type == 'particle':
            self.updater = self.particle_filter_updater
        else:
            raise AttributeError('updater_type {} did not conform to expected type'.format(updater_type))
Code Example #6
File: aa_helpers.py Project: david-abel/simple_rl
def _make_mini_mdp_option_policy(mini_mdp):
    '''
    Args:
        mini_mdp (MDP)

    Returns:
        Policy
    '''
    # Solve the MDP defined by the terminal abstract state.
    mini_mdp_vi = ValueIteration(mini_mdp, delta=0.001, max_iterations=1000, sample_rate=10)
    iters, val = mini_mdp_vi.run_vi()

    o_policy_dict = make_dict_from_lambda(mini_mdp_vi.policy, mini_mdp_vi.get_states())
    o_policy = PolicyFromDict(o_policy_dict)

    return o_policy.get_action
Code Example #7
def _make_mini_mdp_option_policy(mini_mdp, initiating_states):
    '''
    Args:
        mini_mdp (MDP)

    Returns:
        Policy
    '''
    # Solve the MDP defined by the terminal abstract state.
    if isinstance(mini_mdp, OptionsMDP):
        mini_mdp_vi = OptionsMDPValueIteration(mini_mdp,
                                               delta=0.005,
                                               max_iterations=1000,
                                               sample_rate=30)
    else:
        mini_mdp_vi = ValueIteration(mini_mdp,
                                     delta=0.005,
                                     max_iterations=1000,
                                     sample_rate=30)

    for s_g in initiating_states:
        if s_g.is_terminal():
            return lambda s: random.choice(mini_mdp.get_actions(s_g)
                                           ), mini_mdp_vi

    iters, val = mini_mdp_vi.run_vi()
    o_policy_dict = make_dict_from_lambda(
        mini_mdp_vi.policy,
        mini_mdp_vi.get_states() + initiating_states)
    o_policy = PolicyFromDict(o_policy_dict)

    return o_policy.get_action, mini_mdp_vi
Code Example #8
def plan_with_vi(gamma=0.99):
    '''
    Args:
        gamma (float): discount factor

    Running value iteration on the problem to test the correctness of the policy returned by BSS
    '''
    mdp = GridWorldMDP(gamma=gamma, goal_locs=[(4, 3)], slip_prob=0.0)
    value_iter = ValueIteration(mdp, sample_rate=5)
    value_iter.run_vi()

    action_seq, state_seq = value_iter.plan(mdp.get_init_state())

    print "[ValueIteration] Plan for {}".format(mdp)
    for i in range(len(action_seq)):
        print 'pi({}) --> {}'.format(state_seq[i], action_seq[i])
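plan_with_vi shows the basic build-MDP, run-VI, plan-from-the-initial-state pattern. The sketch below reproduces the same pattern in Python 3 syntax, assuming the usual simple_rl import paths (simple_rl.tasks, simple_rl.planning); only the print calls differ from the example above.

# Python 3 sketch of the same VI-then-plan pattern (import paths are assumed).
from simple_rl.tasks import GridWorldMDP
from simple_rl.planning import ValueIteration

def plan_with_vi_py3(gamma=0.99):
    # Build the grid world, solve it with VI, then roll out the greedy plan.
    mdp = GridWorldMDP(gamma=gamma, goal_locs=[(4, 3)], slip_prob=0.0)
    value_iter = ValueIteration(mdp, sample_rate=5)
    value_iter.run_vi()

    action_seq, state_seq = value_iter.plan(mdp.get_init_state())

    print("[ValueIteration] Plan for {}".format(mdp))
    for state, action in zip(state_seq, action_seq):
        print("pi({}) --> {}".format(state, action))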
Code Example #9
def make_random_sa_stack(mdp_distr, cluster_size_ratio=0.5, max_num_levels=2):
    '''
    Args:
        mdp_distr (MDPDistribution)
        cluster_size_ratio (float): A float in (0,1) that determines the size of the abstract state space.
        max_num_levels (int): Determines the _total_ number of levels in the hierarchy (includes ground).

    Returns:
        (StateAbstraction)
    '''

    # Get ground state space.
    vi = ValueIteration(mdp_distr.get_all_mdps()[0],
                        delta=0.0001,
                        max_iterations=5000)
    ground_state_space = vi.get_states()
    sa_stack = StateAbstractionStack(list_of_phi=[])

    # Each loop adds a stack.
    for i in range(max_num_levels - 1):

        # Grab current state space (at level i).
        cur_state_space = _get_level_i_state_space(ground_state_space,
                                                   sa_stack, i)
        cur_state_space_size = len(cur_state_space)

        if int(cur_state_space_size / cluster_size_ratio) <= 1:
            # The abstraction is as small as it can get.
            break

        # Add the mapping.
        new_phi = {}
        for s in cur_state_space:
            new_phi[s] = HierarchyState(data=random.randint(
                1, max(int(cur_state_space_size * cluster_size_ratio), 1)),
                                        level=i + 1)

        if len(set(new_phi.values())) <= 1:
            # The abstraction is as small as it can get.
            break

        # Add the sa to the stack.
        sa_stack.add_phi(new_phi)

    return sa_stack
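The clustering step above assigns each ground state a random abstract state out of at most cur_state_space_size * cluster_size_ratio buckets. A self-contained toy sketch of that mapping, with plain strings and ints standing in for ground states and HierarchyState:

import random

def random_phi(ground_states, cluster_size_ratio=0.5):
    # Map each ground state to a random cluster id, mirroring the new_phi construction above.
    num_clusters = max(int(len(ground_states) * cluster_size_ratio), 1)
    return {s: random.randint(1, num_clusters) for s in ground_states}

print(random_phi(["s{}".format(i) for i in range(10)]))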
Code Example #10
def run(task_block, task_room, red_pos, green_pos, blue_pos, drone_pos, pub,
        drone_path):
    """
    Assume the block is on the floor of each cell.
    Get the initial position of the drone from the caller.
    """
    height = 2  # vertical space
    task = DroneTask(task_block, task_room)
    room1 = DroneRoom("room1", [(x, y, z) for x in range(4) for y in range(1)
                                for z in range(height)], "red")
    room2 = DroneRoom("room2", [(x, y, z) for x in range(0, 2)
                                for y in range(2, 4) for z in range(height)],
                      color="green")
    room3 = DroneRoom("room3", [(x, y, z) for x in range(3, 4)
                                for y in range(2, 4) for z in range(height)],
                      color="blue")
    block1 = DroneBlock("block1",
                        red_pos[0],
                        red_pos[1],
                        red_pos[2] - 1,
                        color="red")
    block2 = DroneBlock("block2",
                        green_pos[0],
                        green_pos[1],
                        green_pos[2] - 1,
                        color="green")
    block3 = DroneBlock("block3",
                        blue_pos[0],
                        blue_pos[1],
                        blue_pos[2] - 1,
                        color="blue")
    rooms = [room1, room2, room3]
    blocks = [block1, block2, block3]
    doors = [DroneDoor(1, 1, height), DroneDoor(3, 1, height)]
    mdp = DroneMDP(drone_pos, task, rooms=rooms, blocks=blocks, doors=doors)

    print("Start Value Iteration")
    vi = ValueIteration(mdp)
    vi.run_vi()
    action_seq, state_seq = vi.plan(mdp.init_state)
    policy = defaultdict()
    for i in range(len(action_seq)):
        policy[state_seq[i]] = action_seq[i]
    print("Start Flying")
    mdp.send_path(policy, pub, drone_path)
Code Example #11
    def get_policy(self, mdp, verbose=False):
        '''
        Args:
            mdp (MDP): MDP (same level as the current Policy Generator)
        Returns:
            policy (defaultdict): optimal policy in mdp
        '''
        vi = ValueIteration(mdp, sample_rate=1)
        vi.run_vi()

        policy = defaultdict()
        action_seq, state_seq = vi.plan(mdp.init_state)

        if verbose: print('Plan for {}:'.format(mdp))
        for i in range(len(action_seq)):
            if verbose:
                print("\tpi[{}] -> {}".format(state_seq[i], action_seq[i]))
            policy[state_seq[i]] = action_seq[i]
        return policy
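get_policy above only records actions for the states along the single trajectory returned by vi.plan. A hedged alternative that covers every state VI reached, relying on the vi.policy and vi.get_states() calls used in the _make_mini_mdp_option_policy examples (the import path is assumed):

from simple_rl.planning import ValueIteration  # assumed import path

def get_full_policy(mdp):
    # Solve the MDP, then read the greedy action for every enumerated state.
    vi = ValueIteration(mdp, sample_rate=1)
    vi.run_vi()
    return {s: vi.policy(s) for s in vi.get_states()}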
Code Example #12
def _make_mini_mdp_option_policy(mini_mdp):
    '''
    Args:
        mini_mdp (MDP)

    Returns:
        Policy
    '''
    # Solve the MDP defined by the terminal abstract state.
    mini_mdp_vi = ValueIteration(mini_mdp,
                                 delta=0.005,
                                 max_iterations=1000,
                                 sample_rate=30)
    iters, val = mini_mdp_vi.run_vi()

    o_policy_dict = make_dict_from_lambda(mini_mdp_vi.policy,
                                          mini_mdp_vi.get_states())
    o_policy = PolicyFromDict(o_policy_dict)

    return o_policy.get_action, mini_mdp_vi
Code Example #13
    def __init__(self,
                 ground_mdp,
                 state_abstr=None,
                 action_abstr=None,
                 vi_sample_rate=5,
                 max_iterations=1000,
                 amdp_sample_rate=5,
                 delta=0.001):
        '''
        Args:
            ground_mdp (simple_rl.MDP)
            state_abstr (simple_rl.StateAbstraction)
            action_abstr (simple_rl.ActionAbstraction)
            vi_sample_rate (int): Num samples per transition for running VI.
            max_iterations (int): Usual VI # Iteration bound.
            amdp_sample_rate (int): Num samples per abstract transition to use for computing R_abstract, T_abstract.
        '''
        self.ground_mdp = ground_mdp

        # Grab ground state space.
        vi = ValueIteration(self.ground_mdp,
                            delta=0.001,
                            max_iterations=1000,
                            sample_rate=5)
        state_space = vi.get_states()

        # Make the abstract MDP.
        self.state_abstr = state_abstr if state_abstr is not None else StateAbstraction(
            ground_state_space=state_space)
        self.action_abstr = action_abstr if action_abstr is not None else ActionAbstraction(
            prim_actions=ground_mdp.get_actions())
        abstr_mdp = abstr_mdp_funcs.make_abstr_mdp(
            ground_mdp,
            self.state_abstr,
            self.action_abstr,
            step_cost=0.0,
            sample_rate=amdp_sample_rate)

        # Create VI with the abstract MDP.
        ValueIteration.__init__(self, abstr_mdp, vi_sample_rate, delta,
                                max_iterations)
Code Example #14
def compute_optimistic_q_function(mdp_distr, sample_rate=5):
    '''
    Instead of transferring an average Q-value, we transfer the highest Q-value across MDPs so that
    it will not underestimate the Q-value.
    '''
    opt_q_func = defaultdict(lambda: defaultdict(lambda: float("-inf")))
    for mdp in mdp_distr.get_mdps():
        # prob_of_mdp = mdp_distr.get_prob_of_mdp(mdp)

        # Get a vi instance to compute state space.
        vi = ValueIteration(mdp,
                            delta=0.0001,
                            max_iterations=1000,
                            sample_rate=sample_rate)
        iters, value = vi.run_vi()
        q_func = vi.get_q_function()
        # print "value =", value
        for s in q_func:
            for a in q_func[s]:
                opt_q_func[s][a] = max(opt_q_func[s][a], q_func[s][a])
    return opt_q_func
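The optimistic transfer above is an elementwise max over per-MDP Q-functions. A self-contained toy sketch of that max step, using made-up Q-tables and no simple_rl dependency:

from collections import defaultdict

# Two toy Q-functions from different MDPs (values are illustrative only).
q_funcs = [
    {"s0": {"a": 0.2, "b": 0.9}},
    {"s0": {"a": 0.7, "b": 0.4}},
]

# Optimistic initialization: keep the highest Q-value seen for each (s, a).
opt_q_func = defaultdict(lambda: defaultdict(lambda: float("-inf")))
for q_func in q_funcs:
    for s in q_func:
        for a in q_func[s]:
            opt_q_func[s][a] = max(opt_q_func[s][a], q_func[s][a])

print({s: dict(qs) for s, qs in opt_q_func.items()})  # {'s0': {'a': 0.7, 'b': 0.9}}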
Code Example #15
def compute_sub_opt_func_for_mdp_distr(mdp_distr):
    '''
    Args:
        mdp_distr (dict)

    Returns:
        (list): Contains the suboptimality function for each MDP in mdp_distr.
            subopt: V^*(s) - Q^*(s,a)
    '''
    actions = mdp_distr.get_actions()
    sub_opt_funcs = []

    i = 0
    for mdp in mdp_distr.get_mdps():
        print "\t mdp", i + 1, "of", mdp_distr.get_num_mdps()
        vi = ValueIteration(mdp, delta=0.001, max_iterations=1000)
        iters, value = vi.run_vi()

        new_sub_opt_func = defaultdict(float)
        for s in vi.get_states():
            max_q = float("-inf")
            for a in actions:
                next_q = vi.get_q_value(s, a)
                if next_q > max_q:
                    max_q = next_q

            for a in actions:
                new_sub_opt_func[(s, a)] = max_q - vi.get_q_value(s, a)

        sub_opt_funcs.append(new_sub_opt_func)
        i += 1

    return sub_opt_funcs
Code Example #16
def get_distance(mdp, epsilon=0.05):

    vi = ValueIteration(mdp)
    vi.run_vi()
    vstar = vi.value_func  # dictionary of state -> float

    states = vi.get_states()  # list of state

    distance = defaultdict(lambda: defaultdict(float))

    v_df = ValueIterationDist(mdp, vstar)
    v_df.run_vi()
    d_to_s = v_df.distance
    for t in states:
        for s in states:
            distance[t][s] = max(d_to_s[t] - 1, 0)

    for s in states:  # s: state
        vis = ValueIterationDist(mdp, vstar)
        vis.add_fixed_val(s, vstar[s])
        vis.run_vi()
        d_to_s = vis.distance
        for t in states:
            distance[t][s] = min(d_to_s[t], distance[t][s])

    sToInd = OrderedDict()
    indToS = OrderedDict()
    for i, s in enumerate(states):
        sToInd[s] = i
        indToS[i] = s

    d = np.zeros((len(states), len(states)), dtype=int)
    # print "type(d)=", type(d)
    # print "d.shape=", d.shape
    for s in states:
        for t in states:
            # print 's, t=', index[s], index[t]
            d[sToInd[s]][sToInd[t]] = distance[s][t]

    return sToInd, indToS, d
Code Example #17
def main():

    # ========================
    # === Make Environment ===
    # ========================
    mdp_class = "four_room"
    environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, grid_dim=10)
    actions = environment.get_actions()

    # ==========================
    # === Make SA, AA Stacks ===
    # ==========================
    # sa_stack, aa_stack = aa_stack_h.make_random_sa_diropt_aa_stack(environment, max_num_levels=3)
    sa_stack, aa_stack = hierarchy_helpers.make_hierarchy(environment,
                                                          num_levels=3)

    mdp = environment.sample()
    HVI = HierarchicalValueIteration(mdp, sa_stack, aa_stack)
    VI = ValueIteration(mdp)

    h_iters, h_val = HVI.run_vi()
    iters, val = VI.run_vi()
Code Example #18
def compute_optimal_stoch_policy(mdp_distr):
    '''
    Args:
        mdp_distr (defaultdict)

    Returns:
        (lambda)
    '''

    # Key: state
    # Val: dict
    #   Key: action
    #   Val: probability
    policy_dict = defaultdict(lambda: defaultdict(float))

    # Compute optimal policy for each MDP.
    for mdp in mdp_distr.get_all_mdps():
        # Solve the MDP and get the optimal policy.
        vi = ValueIteration(mdp, delta=0.001, max_iterations=1000)
        iters, value = vi.run_vi()
        vi_policy = vi.policy
        states = vi.get_states()

        # Compute the probability each action is optimal in each state.
        prob_of_mdp = mdp_distr.get_prob_of_mdp(mdp)
        for s in states:
            a_star = vi_policy(s)
            policy_dict[s][a_star] += prob_of_mdp

    # Create the lambda.
    def policy_from_dict(state):
        action_id = np.random.multinomial(
            1, policy_dict[state].values()).tolist().index(1)
        action = policy_dict[state].keys()[action_id]

        return action

    return policy_from_dict
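Note that policy_from_dict above indexes policy_dict[state].values() and .keys() positionally, which works in Python 2 (where they return lists) but not in Python 3. A toy Python 3 sketch of the same multinomial action sampling, with an illustrative action-probability dictionary:

import numpy as np

# Toy action-probability dictionary for a single state (values are illustrative).
action_probs = {"up": 0.7, "down": 0.1, "left": 0.1, "right": 0.1}

# Materialize keys/values as parallel sequences so positional indexing is well defined.
actions, probs = zip(*action_probs.items())
action_id = np.random.multinomial(1, probs).tolist().index(1)
print(actions[action_id])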
Code Example #19
    def __init__(self,
                 mdp,
                 transition_func,
                 reward_func,
                 observation_func,
                 updater_type='discrete'):
        '''
        Args:
            mdp (POMDP)
            transition_func: T(s, a) --> s'
            reward_func: R(s, a) --> float
            observation_func: O(s, a) --> z
            updater_type (str)
        '''
        self.reward_func = reward_func
        self.updater_type = updater_type

        # We use the ValueIteration class to construct the transition and
        # observation probabilities
        self.vi = ValueIteration(mdp, sample_rate=100)

        print("Constructing transition matrix")
        self.transition_probs = self.construct_transition_matrix(
            transition_func)
        print("Constructing observation matrix")
        self.observation_probs = self.construct_observation_matrix(
            observation_func, transition_func)

        if updater_type == 'discrete':
            self.updater = self.discrete_filter_updater
        elif updater_type == 'kalman':
            self.updater = self.kalman_filter_updater
        elif updater_type == 'particle':
            self.updater = self.particle_filter_updater
        else:
            raise AttributeError(
                'updater_type {} did not conform to expected type'.format(
                    updater_type))
Code Example #20
File: sa_helpers.py Project: david-abel/simple_rl
def make_multitask_sa(mdp_distr, state_class=State, indic_func=ind_funcs._q_eps_approx_indicator, epsilon=0.0, aa_single_act=True, track_act_opt_pr=False):
    '''
    Args:
        mdp_distr (MDPDistribution)
        state_class (Class)
        indic_func (S x S --> {0,1})
        epsilon (float)
        aa_single_act (bool): If we should track optimal actions.

    Returns:
        (StateAbstraction)
    '''
    sa_list = []
    for mdp in mdp_distr.get_mdps():
        sa = make_singletask_sa(mdp, indic_func, state_class, epsilon, aa_single_act=aa_single_act, prob_of_mdp=mdp_distr.get_prob_of_mdp(mdp), track_act_opt_pr=track_act_opt_pr)
        sa_list += [sa]

    mdp = mdp_distr.get_all_mdps()[0]
    vi = ValueIteration(mdp)
    ground_states = vi.get_states()
    multitask_sa = merge_state_abstr(sa_list, ground_states)

    return multitask_sa
Code Example #21
    def __init__(self, ground_mdp, state_abstr=None, action_abstr=None, vi_sample_rate=5, max_iterations=1000, amdp_sample_rate=5, delta=0.001):
        '''
        Args:
            ground_mdp (simple_rl.MDP)
            state_abstr (simple_rl.StateAbstraction)
            action_abstr (simple_rl.ActionAbstraction)
            vi_sample_rate (int): Num samples per transition for running VI.
            max_iterations (int): Usual VI # Iteration bound.
            amdp_sample_rate (int): Num samples per abstract transition to use for computing R_abstract, T_abstract.
        '''
        self.ground_mdp = ground_mdp
    
        # Grab ground state space.
        vi = ValueIteration(self.ground_mdp, delta=0.001, max_iterations=1000, sample_rate=5)
        state_space = vi.get_states()

        # Make the abstract MDP.
        self.state_abstr = state_abstr if state_abstr is not None else StateAbstraction(ground_state_space=state_space)
        self.action_abstr = action_abstr if action_abstr is not None else ActionAbstraction(prim_actions=ground_mdp.get_actions())
        abstr_mdp = abstr_mdp_funcs.make_abstr_mdp(ground_mdp, self.state_abstr, self.action_abstr, step_cost=0.0, sample_rate=amdp_sample_rate)

        # Create VI with the abstract MDP.
        ValueIteration.__init__(self, abstr_mdp, vi_sample_rate, delta, max_iterations)
Code Example #22
 def update_init_q_function(self, mdp):
     if self.task_number == 0:
         self.default_q_func = copy.deepcopy(self.default_q_func)
     elif self.task_number < self.num_sample_tasks:
         new_q_func = self.q_func
         for x in new_q_func:
             for y in new_q_func[x]:
                 self.default_q_func[x][y] = max(new_q_func[x][y],
                                                 self.default_q_func[x][y])
     elif self.task_number == self.num_sample_tasks:
         vi = ValueIteration(mdp,
                             delta=0.1,
                             max_iterations=2,
                             sample_rate=1)
         _, _ = vi.run_vi()
         new_q_func = vi.get_q_function()  # VI to enumerate all states
         for s in new_q_func:
             for a in new_q_func[s]:
                 if self.default_q_func[s][
                         a] < 0:  # If (s, a) is never visited set Vmax
                     self.default_q_func[s][a] = self.default_q
         print(self.name, "Initial Q func from", self.task_number, "tasks")
         self.print_dict(self.default_q_func)
Code Example #23
    def planFromAtoB(self, Maps, nearestVertex, kStepConfig):

        # if not self.computedMDP:
        #     self.wallLocations = []
        #     for x in range(len(self.Maps.occupancyMap)):
        #         for y in range(len(self.Maps.occupancyMap[x])):
        #             if self.Maps.occupancyMap[x][y] == Env.WALL:
        #                 self.wallLocations.append(Loc.Location(x,y))
        #     self.computedMDP = True

        mdp = GridWorldMDP(width=len(Maps.occupancyMap),
                           height=len(Maps.occupancyMap[0]),
                           init_loc=(nearestVertex.x, nearestVertex.y),
                           goal_locs=[(kStepConfig.x, kStepConfig.y)],
                           gamma=0.95)
        vi = ValueIteration(mdp)
        vi.run_vi()
        action_seq, state_seq = vi.plan()

        #check if conflict
        for s in state_seq:
            if Maps.occupancyMap[s[0], s[1]] == env.WALL:
                return False
        return True
Code Example #24
    def __init__(self):
        self.base_human_model = PuddleMDP(step_cost=1.0)
        self.base_agent = ValueIteration(self.base_human_model,
                                         max_iterations=5000,
                                         sample_rate=1)
        self.sample_agent = ModQLearningAgent(
            actions=self.base_human_model.get_actions(),
            epsilon=0.5,
            anneal=True)
        #run_single_agent_on_mdp(self.base_agent, self.base_human_model, episodes=10000, steps=60, verbose=True)
        self.base_agent.run_vi()

        #print ("Q func", self.base_agent.q_func)
        self.test_run = False

        if self.test_run:
            self.novice_model_1 = self.base_human_model
            self.novice_model_2 = self.base_human_model
            self.fully_actulized_model = self.base_human_model

            self.novice_agent_1 = self.base_agent
            self.novice_agent_2 = self.base_agent
            self.fully_actulized_agent = self.base_agent
        else:

            self.novice_model_1 = PuddleMDP2(step_cost=1.0)
            self.novice_agent_1 = ValueIteration(self.novice_model_1)
            self.novice_agent_1.run_vi()

            self.novice_model_2 = PuddleMDP3(step_cost=1.0)
            self.novice_agent_2 = ValueIteration(self.novice_model_2)
            self.novice_agent_2.run_vi()

            self.fully_actulized_model = PuddleMDP4(step_cost=1.0)
            self.fully_actulized_agent = ValueIteration(
                self.fully_actulized_model)
            self.fully_actulized_agent.run_vi()
            #self.fully_actulized_agent = ModQLearningAgent(actions=self.fully_actulized_model.get_actions(), epsilon=0.5, anneal=True)
            #run_single_agent_on_mdp(self.fully_actulized_agent, self.fully_actulized_model, episodes=10000, steps=60, verbose=True)

        # TODO Add other settings

        self.current_agent = self.base_agent
        self.current_mdp = self.base_human_model
Code Example #25
class PUDDLER:
    def __init__(self):
        self.base_human_model = PuddleMDP(step_cost=1.0)
        self.base_agent = ValueIteration(self.base_human_model,
                                         max_iterations=5000,
                                         sample_rate=1)
        self.sample_agent = ModQLearningAgent(
            actions=self.base_human_model.get_actions(),
            epsilon=0.5,
            anneal=True)
        #run_single_agent_on_mdp(self.base_agent, self.base_human_model, episodes=10000, steps=60, verbose=True)
        self.base_agent.run_vi()

        #print ("Q func", self.base_agent.q_func)
        self.test_run = False

        if self.test_run:
            self.novice_model_1 = self.base_human_model
            self.novice_model_2 = self.base_human_model
            self.fully_actulized_model = self.base_human_model

            self.novice_agent_1 = self.base_agent
            self.novice_agent_2 = self.base_agent
            self.fully_actulized_agent = self.base_agent
        else:

            self.novice_model_1 = PuddleMDP2(step_cost=1.0)
            self.novice_agent_1 = ValueIteration(self.novice_model_1)
            self.novice_agent_1.run_vi()

            self.novice_model_2 = PuddleMDP3(step_cost=1.0)
            self.novice_agent_2 = ValueIteration(self.novice_model_2)
            self.novice_agent_2.run_vi()

            self.fully_actulized_model = PuddleMDP4(step_cost=1.0)
            self.fully_actulized_agent = ValueIteration(
                self.fully_actulized_model)
            self.fully_actulized_agent.run_vi()
            #self.fully_actulized_agent = ModQLearningAgent(actions=self.fully_actulized_model.get_actions(), epsilon=0.5, anneal=True)
            #run_single_agent_on_mdp(self.fully_actulized_agent, self.fully_actulized_model, episodes=10000, steps=60, verbose=True)

        # TODO Add other settings

        self.current_agent = self.base_agent
        self.current_mdp = self.base_human_model

    def get_init_info(self):
        data_points = []
        return data_points

    def get_human_reinf_from_prev_step(self,
                                       state,
                                       action,
                                       explanation_features=[0, 0]):
        delta = 0.1
        print(explanation_features)
        if explanation_features[1] == 1 and explanation_features[0] == 1:
            self.current_mdp = self.fully_actulized_model
            self.current_agent = self.fully_actulized_agent
        elif explanation_features[0] == 1:
            self.current_mdp = self.novice_model_1
            self.current_agent = self.novice_agent_1
        elif explanation_features[1] == 1:
            self.current_mdp = self.novice_model_2
            self.current_agent = self.novice_agent_2
        else:
            self.current_mdp = self.base_human_model
            self.current_agent = self.base_agent

        curr_best_q_val = self.current_agent.get_value(state)
        curr_q_val = self.current_agent.get_q_value(state, action)
        #        return curr_q_val - curr_best_q_val
        return min((float(curr_best_q_val - curr_q_val) + delta) /
                   (float(curr_best_q_val) + delta), 1)

    def get_possible_actions(self):
        return self.base_human_model.get_actions()

    def get_best_action(self, state, explanation_features=[0, 0]):
        if explanation_features[1] == 1 and explanation_features[0] == 1:
            self.current_mdp = self.fully_actulized_model
            self.current_agent = self.fully_actulized_agent
        elif explanation_features[0] == 1:
            self.current_mdp = self.novice_model_1
            self.current_agent = self.novice_agent_1
        elif explanation_features[1] == 1:
            self.current_mdp = self.novice_model_2
            self.current_agent = self.novice_agent_2
        else:
            self.current_mdp = self.base_human_model
            self.current_agent = self.base_agent

        return self.current_agent._get_max_q_action(state)

    def get_initial_state(self):
        # TODO Randomize
        return self.base_human_model.get_init_state()

    def get_initial_state_features(self):
        return self.base_human_model.get_init_state().features()

    def get_next_state(self, state, act, explanation_features=[0]):
        if explanation_features[0] >= 0.5:
            self.current_mdp = self.fully_actulized_model
            self.current_agent = self.fully_actulized_agent
        else:
            self.current_mdp = self.base_human_model
            self.current_agent = self.base_agent

        self.current_mdp.set_state(state)
        reward, new_state = self.current_mdp.execute_agent_action(act)
        return new_state

    def set_state(self, x, y):
        state = GridWorldState(x, y)
        self.base_human_model.set_state(state)
        return state

    def visualize_agent(self, state):
        self.base_human_model.set_state(state)
        self.base_human_model.visualize_state(self.sample_agent)
Code Example #26
    def __init__(self,
                 ground_mdp,
                 state_abstr=None,
                 action_abstr=None,
                 sample_rate=10,
                 delta=0.001,
                 max_iterations=1000):
        '''
        Args:
            ground_mdp (MDP)
            state_abstr (StateAbstraction)
            action_abstr (ActionAbstraction)
        '''
        self.ground_mdp = ground_mdp
        self.state_abstr = state_abstr if state_abstr not in [
            [], None
        ] else StateAbstraction()
        self.action_abstr = action_abstr if action_abstr not in [
            [], None
        ] else ActionAbstraction(prim_actions=ground_mdp.get_actions())

        mdp = make_abstr_mdp(ground_mdp, self.state_abstr, self.action_abstr)

        ValueIteration.__init__(self, mdp, sample_rate, delta, max_iterations)


#         self.delta = delta
#         self.max_iterations = max_iterations
#         self.sample_rate = sample_rate

#         self.value_func = defaultdict(float)
#         self.reachability_done = False
#         self.has_run_vi = False
#         self._compute_reachable_state_space()

#     def get_num_states(self):
#         return len(self.states)

#     def get_states(self):
#         if self.reachability_done:
#             return self.states
#         else:
#             self._compute_reachable_state_space()
#             return self.states

#     def _compute_reachable_state_space(self):
#         '''
#         Summary:
#             Starting with @self.start_state, determines all reachable states
#             and stores their abstracted counterparts in self.states.
#         '''
#         state_queue = Queue.Queue()
#         s_g_init = self.mdp.get_init_state()
#         s_a_init = self.state_abstr.phi(s_g_init)
#         state_queue.put(s_g_init)
#         self.states.add(s_a_init)
#         ground_t = self.mdp.get_transition_func()

#         while not state_queue.empty():
#             ground_state = state_queue.get()
#             for option in self.action_abstr.get_active_options(ground_state):
#                 # For each active option.

#                 # Take @sample_rate samples to estimate E[V]
#                 for samples in xrange(self.sample_rate):

#                     next_g_state = option.act_until_terminal(ground_state, ground_t)

#                     if next_g_state not in self.states:
#                         next_a_state = self.state_abstr.phi(next_g_state)
#                         self.states.add(next_a_state)
#                         state_queue.put(next_g_state)

#         self.reachability_done = True

#     def plan(self, ground_state=None, horizon=100):
#         '''
#         Args:
#             ground_state (State)
#             horizon (int)

#         Returns:
#             (tuple):
#                 (list): List of primitive actions taken.
#                 (list): List of ground states.
#                 (list): List of abstract actions taken.
#         '''

#         ground_state = self.mdp.get_init_state() if ground_state is None else ground_state

#         if self.has_run_vi is False:
#             print "Warning: VI has not been run. Plan will be random."

#         primitive_action_seq = []
#         abstr_action_seq = []
#         state_seq = [ground_state]
#         steps = 0

#         ground_t = self.transition_func

#         # Until terminating condition is met.
#         while (not ground_state.is_terminal()) and steps < horizon:

#             # Compute best action, roll it out.
#             next_option = self._get_max_q_action(ground_state)

#             while not next_option.is_term_true(ground_state):
#                 # Keep applying option until it terminates.
#                 abstr_state = self.state_abstr.phi(ground_state)
#                 ground_action = next_option.act(ground_state)
#                 ground_state = ground_t(ground_state, ground_action)
#                 steps += 1
#                 primitive_action_seq.append(ground_action)

#                 state_seq.append(ground_state)

#             abstr_action_seq.append(next_option)

#         return primitive_action_seq, state_seq, abstr_action_seq

#     def run_vi(self):
#         '''
#         Summary:
#             Runs ValueIteration and fills in the self.value_func.
#         '''
#         # Algorithm bookkeeping params.
#         iterations = 0
#         max_diff = float("inf")

#         # Main loop.
#         while max_diff > self.delta and iterations < self.max_iterations:
#             max_diff = 0
#             for s_g in self.get_states():
#                 if s_g.is_terminal():
#                     continue

#                 max_q = float("-inf")
#                 for a in self.action_abstr.get_active_options(s_g):
#                     # For each active option, compute it's q value.
#                     q_s_a = self.get_q_value(s_g, a)
#                     max_q = q_s_a if q_s_a > max_q else max_q

#                 # Check terminating condition.
#                 max_diff = max(abs(self.value_func[s_g] - max_q), max_diff)

#                 # Update value.
#                 self.value_func[s_g] = max_q

#             iterations += 1

#         value_of_init_state = self._compute_max_qval_action_pair(self.init_state)[0]

#         self.has_run_vi = True

#         return iterations, value_of_init_state

#     def get_q_value(self, s_g, option):
#         '''
#         Args:
#             s (State)
#             a (Option): Assumed active option.

#         Returns:
#             (float): The Q estimate given the current value function @self.value_func.
#         '''

#         # Take samples and track next state counts.
#         next_state_counts = defaultdict(int)
#         reward_total = 0
#         for samples in xrange(self.sample_rate): # Take @sample_rate samples to estimate E[V]
#             next_state, reward, num_steps = self.do_rollout(option, s_g)
#             next_state_counts[next_state] += 1
#             reward_total += reward

#         # Compute T(s' | s, option) estimate based on MLE and R(s, option).
#         next_state_probs = defaultdict(float)
#         avg_reward = 0
#         for state in next_state_counts:
#             next_state_probs[state] = float(next_state_counts[state]) / self.sample_rate

#         avg_reward = float(reward_total) / self.sample_rate

#         # Compute expected value.
#         expected_future_val = 0
#         for state in next_state_probs:
#             expected_future_val += next_state_probs[state] * self.value_func[state]

#         return avg_reward + self.gamma*expected_future_val

#     def do_rollout(self, option, ground_state):
#         '''
#         Args:
#             option (Option)
#             ground_state (State)

#         Returns:
#             (tuple):
#                 (State): Next ground state.
#                 (float): Reward.
#                 (int): Number of steps taken.
#         '''

#         ground_t = self.mdp.get_transition_func()
#         ground_r = self.mdp.get_reward_func()

#         if type(option) is str:
#             ground_action = option
#         else:
#             ground_action = option.act(ground_state)
#         total_reward = ground_r(ground_state, ground_action)
#         ground_state = ground_t(ground_state, ground_action)

#         total_steps = 1
#         while type(option) is not str and not option.is_term_true(ground_state):
#             # Keep applying option until it terminates.
#             ground_action = option.act(ground_state)
#             total_reward += ground_r(ground_state, ground_action)
#             ground_state = ground_t(ground_state, ground_action)
#             total_steps += 1

#         return ground_state, total_reward, total_steps

#     def _compute_max_qval_action_pair(self, state):
#         '''
#         Args:
#             state (State)

#         Returns:
#             (tuple) --> (float, str): where the float is the Qval, str is the action.
#         '''
#         # Grab random initial action in case all equal
#         max_q_val = float("-inf")
#         shuffled_option_list = self.action_abstr.get_active_options(state)[:]
#         if len(shuffled_option_list) == 0:
#         	# Prims on failure.
#         	shuffled_option_list = self.mdp.get_actions()

#         random.shuffle(shuffled_option_list)
#         best_action = shuffled_option_list[0]

#         # Find best action (action w/ current max predicted Q value)
#         for option in shuffled_option_list:
#             q_s_a = self.get_q_value(state, option)
#             if q_s_a > max_q_val:
#                 max_q_val = q_s_a
#                 best_action = option

#         return max_q_val, best_action

#     def _get_max_q_action(self, state):
#         '''
#         Args:
#             state (State)

#         Returns:
#             (str): denoting the action with the max q value in the given @state.
#         '''
#         return self._compute_max_qval_action_pair(state)[1]

#     def policy(self, state):
#         '''
#         Args:
#             state (State)

#         Returns:
#             (str): Action

#         Summary:
#             For use in a FixedPolicyAgent.
#         '''
#         return self._get_max_q_action(state)

# def main():
#     # MDP Setting.
#     multi_task = False
#     mdp_class = "grid"

#     # Make single/multi task environment.
#     environment = make_mdp.make_mdp_distr(mdp_class=mdp_class, num_mdps=3, horizon=30) if multi_task else make_mdp.make_mdp(mdp_class=mdp_class)
#     actions = environment.get_actions()
#     gamma = environment.get_gamma()

#     directed_sa, directed_aa = ae.get_abstractions(environment, directed=True)
#     default_sa, default_aa = ae.get_sa(environment, default=True), ae.get_aa(environment, default=True)

#     vi = ValueIteration(environment)
#     avi = AbstractValueIteration(environment, state_abstr=default_sa, action_abstr=default_aa)

#     a_num_iters, a_val = avi.run_vi()
#     g_num_iters, g_val = vi.run_vi()

#     print "a", a_num_iters, a_val
#     print "g", g_num_iters, g_val

# if __name__ == "__main__":
#     main()
Code Example #27
def main(eps=0.1, open_plot=True):

    mdp_class, is_goal_terminal, samples, alg = parse_args()

    # Setup multitask setting.
    mdp_distr = make_mdp.make_mdp_distr(mdp_class=mdp_class)
    actions = mdp_distr.get_actions()

    # Compute average MDP.
    print "Making and solving avg MDP...",
    sys.stdout.flush()
    avg_mdp = compute_avg_mdp(mdp_distr)
    avg_mdp_vi = ValueIteration(avg_mdp,
                                delta=0.001,
                                max_iterations=1000,
                                sample_rate=5)
    iters, value = avg_mdp_vi.run_vi()

    ### Yuu

    transfer_fixed_agent = FixedPolicyAgent(avg_mdp_vi.policy,
                                            name="transferFixed")
    rand_agent = RandomAgent(actions, name="$\pi^u$")

    opt_q_func = compute_optimistic_q_function(mdp_distr)
    avg_q_func = avg_mdp_vi.get_q_function()

    if alg == "q":
        pure_ql_agent = QLearnerAgent(actions, epsilon=eps, name="Q-0")
        qmax = 1.0 * (1 - 0.99)
        # qmax = 1.0
        pure_ql_agent_opt = QLearnerAgent(actions,
                                          epsilon=eps,
                                          default_q=qmax,
                                          name="Q-vmax")
        transfer_ql_agent_optq = QLearnerAgent(actions,
                                               epsilon=eps,
                                               name="Q-trans-max")
        transfer_ql_agent_optq.set_init_q_function(opt_q_func)
        transfer_ql_agent_avgq = QLearnerAgent(actions,
                                               epsilon=eps,
                                               name="Q-trans-avg")
        transfer_ql_agent_avgq.set_init_q_function(avg_q_func)

        agents = [
            pure_ql_agent, pure_ql_agent_opt, transfer_ql_agent_optq,
            transfer_ql_agent_avgq
        ]
    elif alg == "rmax":
        pure_rmax_agent = RMaxAgent(actions, name="RMAX-vmax")
        updating_trans_rmax_agent = UpdatingRMaxAgent(actions,
                                                      name="RMAX-updating_max")
        trans_rmax_agent = RMaxAgent(actions, name="RMAX-trans_max")
        trans_rmax_agent.set_init_q_function(opt_q_func)
        agents = [pure_rmax_agent, updating_trans_rmax_agent, trans_rmax_agent]
    elif alg == "delayed-q":
        pure_delayed_ql_agent = DelayedQLearnerAgent(actions,
                                                     opt_q_func,
                                                     name="DelayedQ-vmax")
        pure_delayed_ql_agent.set_vmax()
        updating_delayed_ql_agent = UpdatingDelayedQLearnerAgent(
            actions, name="DelayedQ-updating_max")
        trans_delayed_ql_agent = DelayedQLearnerAgent(
            actions, opt_q_func, name="DelayedQ-trans-max")
        agents = [
            pure_delayed_ql_agent, updating_delayed_ql_agent,
            trans_delayed_ql_agent
        ]
    else:
        print "Unknown type of agents:", alg
        print "(q, rmax, delayed-q)"
        assert (False)

    # Run task.
    # TODO: Function for Learning on each MDP
    run_agents_multi_task(agents,
                          mdp_distr,
                          task_samples=samples,
                          episodes=1,
                          steps=100,
                          reset_at_terminal=is_goal_terminal,
                          is_rec_disc_reward=False,
                          cumulative_plot=True,
                          open_plot=open_plot)
Code Example #28
def main():
    # This accepts arguments from the command line with flags.
    # Example usage: python value_iteration_demo.py -w 4 -H 3 -s 0.05 -g 0.95 -il [(0,0)] -gl [(4,3)] -ll [(4,2)]  -W [(2,2)]
    parser = argparse.ArgumentParser(
        description=
        'Run a demo that shows a visualization of value iteration on a GridWorld MDP'
    )

    # Add the relevant arguments to the argparser
    parser.add_argument(
        '-w',
        '--width',
        type=int,
        nargs="?",
        const=5,
        default=5,
        help=
        'an integer representing the number of cells for the GridWorld width')
    parser.add_argument(
        '-H',
        '--height',
        type=int,
        nargs="?",
        const=5,
        default=5,
        help=
        'an integer representing the number of cells for the GridWorld height')
    parser.add_argument(
        '-s',
        '--slip',
        type=float,
        nargs="?",
        const=0.05,
        default=0.05,
        help=
        'a float representing the probability that the agent will "slip" and not take the intended action'
    )
    parser.add_argument(
        '-g',
        '--gamma',
        type=float,
        nargs="?",
        const=0.95,
        default=0.95,
        help='a float representing the decay factor for Value Iteration')
    parser.add_argument(
        '-il',
        '--i_loc',
        type=tuple,
        nargs="?",
        const=(0, 0),
        default=(0, 0),
        help=
        'two integers representing the starting cell location of the agent [with zero-indexing]'
    )
    parser.add_argument(
        '-gl',
        '--g_loc',
        type=list,
        nargs="?",
        const=[(3, 3)],
        default=[(3, 3)],
        help=
        'a sequence of integer-valued coordinates where the agent will receive a large reward and enter a terminal state'
    )
    args = parser.parse_args()
    mdp = generate_MDP(args.width, args.height, args.i_loc, args.g_loc,
                       args.gamma, args.slip)

    # Run value iteration on the mdp and save the history of value backups until convergence
    vi = ValueIteration(mdp, max_iterations=1)
    _, _, histories = vi.run_vi_histories()

    # For every value backup, visualize the policy
    for value_dict in histories:
        #mdp.visualize_policy(lambda in_state: value_dict[in_state]) # Note: This lambda is necessary because the policy must be a function
        #time.sleep(0.5)
        print("========================")
        for k in value_dict.keys():
            print(str(k) + " " + str(value_dict[k]))
        print(vi.plan(state=mdp.init_state))
Code Example #29
def main():
    # This accepts arguments from the command line with flags.
    # Example usage: python value_iteration_update_viz_example.py -w 7 -H 5 -s 0.05 -g 0.95
    #   -il '(1,1)' -gl '[(7,4)]' -ll '[(7,3)]' -W '[(2,2)]'
    # Examples WINDOWS Usage: python value_iteration_update_viz_example.py -w 7 -H 5 -s 0.05
    #   -g 0.95 -il (1,1) -gl [(7,4)] -ll [(7,3)] -W [(2,2)]
    parser = argparse.ArgumentParser(
        description='Run a demo that shows a visualization of value' +
        ' iteration on a GridWorld MDP. \n Notes: \n 1.' +
        ' Goal states should appear as green circles, lava' +
        ' states should be red circles and the agent start' +
        ' location should appear with a blue triangle. If' +
        ' these are not shown, you have probably passed in' +
        ' a value that is outside the grid \n 2.' +
        ' This program is intended to provide a visualization' +
        ' of Value Iteration after every iteration of the algorithm.' +
        ' Once you pass in the correct arguments, a PyGame screen should pop-up.' +
        ' Press the esc key to view the next iteration and the q key to quit' +
        '\n 3. The program prints the total time taken for VI to run in seconds' +
        ' and the number of iterations (as the history) to the console.')

    # Add the relevant arguments to the argparser
    parser.add_argument(
        '-w',
        '--width',
        type=int,
        nargs="?",
        const=4,
        default=4,
        help=
        'an integer representing the number of cells for the GridWorld width')
    parser.add_argument(
        '-H',
        '--height',
        type=int,
        nargs="?",
        const=3,
        default=3,
        help=
        'an integer representing the number of cells for the GridWorld height')
    parser.add_argument(
        '-s',
        '--slip',
        type=float,
        nargs="?",
        const=0.05,
        default=0.05,
        help=
        'a float representing the probability that the agent will "slip" and not take the intended action but take a random action at uniform instead'
    )
    parser.add_argument(
        '-g',
        '--gamma',
        type=float,
        nargs="?",
        const=0.95,
        default=0.95,
        help='a float representing the decay factor for Value Iteration')
    parser.add_argument(
        '-il',
        '--i_loc',
        type=ast.literal_eval,
        nargs="?",
        const=(1, 1),
        default=(1, 1),
        help=
        "a tuple of integers representing the starting cell location of the agent, with one-indexing. For example, do -il '(1,1)' , be sure to include apostrophes (unless you use Windows) or argparse will fail!"
    )
    parser.add_argument(
        '-gl',
        '--g_loc',
        type=ast.literal_eval,
        nargs="?",
        const=[(3, 3)],
        default=[(3, 3)],
        help=
        "a list of tuples of of integer-valued coordinates where the agent will receive a large reward and enter a terminal state. Each coordinate is a location on the grid with one-indexing. For example, do -gl '[(3,3)]' , be sure to include apostrophes (unless you use Windows) or argparse will fail!"
    )
    parser.add_argument(
        '-ll',
        '--l_loc',
        type=ast.literal_eval,
        nargs="?",
        const=[(3, 2)],
        default=[(3, 2)],
        help=
        "a list of tuples of of integer-valued coordinates where the agent will receive a large negative reward and enter a terminal state. Each coordinate is a location on the grid with one-indexing. For example, do -ll '[(3,2)]' , be sure to include apostrophes (unless you use Windows) or argparse will fail!"
    )
    parser.add_argument(
        '-W',
        '--Walls',
        type=ast.literal_eval,
        nargs="?",
        const=[(2, 2)],
        default=[(2, 2)],
        help=
        "a list of tuples of of integer-valued coordinates where there are 'walls' that the agent can't transition into. Each coordinate is a location on the grid with one-indexing. For example, do -W '[(3,2)]' , be sure to include apostrophes (unless you use Windows) or argparse will fail!"
    )
    parser.add_argument(
        '-d',
        '--delta',
        type=float,
        nargs="?",
        const=0.0001,
        default=0.0001,
        help=
        'After an iteration of VI, terminate if no change greater than delta has occurred.'
    )
    parser.add_argument('-m',
                        '--max-iter',
                        type=int,
                        nargs="?",
                        const=500,
                        default=500,
                        help='Maximum number of iterations VI runs for')
    parser.add_argument('--skip',
                        action='store_true',
                        help='Skip to last frame or not')

    args = parser.parse_args()
    if args.skip is None:
        args.skip = False

    mdp = generate_MDP(args.width, args.height, args.i_loc, args.g_loc,
                       args.l_loc, args.gamma, args.Walls, args.slip)

    # Run value iteration on the mdp and save the history of value backups until convergence
    st = time.time()
    vi = ValueIteration(mdp, max_iterations=args.max_iter, delta=args.delta)
    num_hist, _, q_act_histories, val_histories = vi.run_vi_histories()
    end = time.time()

    print('Took {:.4f} seconds'.format(end - st))

    # For every value backup, visualize the policy
    if args.skip:
        mdp.visualize_policy_values(
            (lambda in_state: q_act_histories[-1][in_state]),
            (lambda curr_state: val_histories[-1][curr_state]))
    else:
        for i in range(num_hist):
            print('Showing history {:04d} of {:04d}'.format(i + 1, num_hist))
            # Note: This lambda is necessary because the policy must be a function
            mdp.visualize_policy_values(
                (lambda in_state: q_act_histories[i][in_state]),
                (lambda curr_state: val_histories[i][curr_state]))
Code Example #30
    def update_init_q_function(self, mdp):
        '''
        If sample_with_q is True, run Q-learning for sample tasks.
        If qstar_transfer is True, run value iteration for sample tasks to get Q*.
        Else, run delayed Q-learning for sample tasks
        '''
        if self.sample_with_q:
            if self.task_number == 0:
                self.init_q_func = copy.deepcopy(self.q_agent.q_func)
            elif self.task_number < self.num_sample_tasks:
                new_q_func = self.q_agent.q_func
                for x in new_q_func:
                    for y in new_q_func[x]:
                        self.init_q_func[x][y] = max(new_q_func[x][y],
                                                     self.init_q_func[x][y])
        elif self.qstar_transfer:
            if self.task_number == 0:
                self.init_q_func = defaultdict(
                    lambda: defaultdict(lambda: float("-inf")))
            # else:
            elif self.task_number < self.num_sample_tasks:
                vi = ValueIteration(mdp,
                                    delta=0.0001,
                                    max_iterations=2000,
                                    sample_rate=5)
                _, _ = vi.run_vi()
                new_q_func = vi.get_q_function()
                for x in new_q_func:
                    for y in new_q_func[x]:
                        self.init_q_func[x][y] = max(new_q_func[x][y],
                                                     self.init_q_func[x][y])
        else:
            if self.task_number == 0:
                self.init_q_func = defaultdict(
                    lambda: defaultdict(lambda: float("-inf")))
            elif self.task_number < self.num_sample_tasks:
                new_q_func = self.q_func
                for x in new_q_func:
                    assert len(self.init_q_func[x]) <= len(new_q_func[x])
                    for y in new_q_func[x]:
                        self.init_q_func[x][y] = max(new_q_func[x][y],
                                                     self.init_q_func[x][y])
                        assert (self.init_q_func[x][y] <= self.default_q)

                ### Uncomment the code below to check whether the learned Q-values are close enough to the optimal Q-values.
                # Compare q_func learned vs. the true Q value.
                # vi = ValueIteration(mdp, delta=0.001, max_iterations=2000, sample_rate=5)
                # _, _ = vi.run_vi()
                # qstar_func = vi.get_q_function()  # VI to enumerate all states
                # print "Q-function learned by delayed-Q"
                # self.print_dict(new_q_func)
                # print "Optimal Q-function"
                # self.print_dict(qstar_func)

        if self.task_number == self.num_sample_tasks:
            vi = ValueIteration(mdp,
                                delta=0.1,
                                max_iterations=2,
                                sample_rate=1)
            _, _ = vi.run_vi()
            new_q_func = vi.get_q_function()  # VI to enumerate all states
            for s in new_q_func:
                for a in new_q_func[s]:
                    if self.init_q_func[s][a] < 0:  # If (s, a) was never visited, set it to Vmax.
                        self.init_q_func[s][a] = self.default_q
            print(self.name, "Initial Q func from", self.task_number, "tasks")
            self.print_dict(self.init_q_func)
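All three branches above merge the Q-function learned on a sample task into the transferred initialization by taking a pointwise maximum over (state, action) pairs. A minimal standalone sketch of that merge rule, assuming Q-functions stored as nested dicts keyed by state and then action, as in the code above:

from collections import defaultdict

def merge_max_q(init_q_func, new_q_func):
    # Pointwise max: Q_init(s, a) <- max(Q_init(s, a), Q_new(s, a)).
    merged = defaultdict(lambda: defaultdict(lambda: float("-inf")))
    for s in init_q_func:
        for a in init_q_func[s]:
            merged[s][a] = init_q_func[s][a]
    for s in new_q_func:
        for a in new_q_func[s]:
            merged[s][a] = max(merged[s][a], new_q_func[s][a])
    return merged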
Code example #31
0
def make_singletask_sa(mdp,
                       indic_func,
                       state_class,
                       epsilon=0.0,
                       aa_single_act=False,
                       prob_of_mdp=1.0):
    '''
    Args:
        mdp (MDP)
        indic_func (S x S --> {0,1})
        state_class (Class)
        epsilon (float)

    Returns:
        (StateAbstraction)
    '''

    print "\tRunning VI...",
    sys.stdout.flush()
    # Run VI
    if isinstance(mdp, MDPDistribution):
        mdp = mdp.sample()

    vi = ValueIteration(mdp)
    iters, val = vi.run_vi()
    print " done."

    print "\tMaking state abstraction...",
    sys.stdout.flush()
    sa = StateAbstraction(phi={}, state_class=state_class)
    clusters = defaultdict(set)
    num_states = len(vi.get_states())
    actions = mdp.get_actions()

    # Find state pairs that satisfy the condition.
    for i, state_x in enumerate(vi.get_states()):
        sys.stdout.flush()
        clusters[state_x].add(state_x)

        for state_y in vi.get_states()[i:]:
            if not (state_x == state_y) and indic_func(
                    state_x, state_y, vi, actions, epsilon=epsilon):
                clusters[state_x].add(state_y)
                clusters[state_y].add(state_x)

    print "making clusters...",
    sys.stdout.flush()

    # Build SA.
    for i, state in enumerate(clusters.keys()):
        new_cluster = clusters[state]
        sa.make_cluster(new_cluster)

        # Destroy old so we don't double up.
        for s in clusters[state]:
            if s in clusters.keys():
                clusters.pop(s)

    print " done."
    print "\tGround States:", num_states
    print "\tAbstract:", sa.get_num_abstr_states()
    print

    return sa
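For reference, a common choice of indic_func for this kind of abstraction is approximate Q*-similarity: two states are clustered when their optimal Q-values agree to within epsilon for every action. A minimal sketch of such a predicate (one possible indicator, not necessarily the one used by the caller), relying only on the planner's get_q_function() shown elsewhere in these examples:

def q_eps_indicator(state_x, state_y, vi, actions, epsilon=0.0):
    # Cluster state_x and state_y when |Q*(x, a) - Q*(y, a)| <= epsilon for every action.
    q_func = vi.get_q_function()
    return all(abs(q_func[state_x][a] - q_func[state_y][a]) <= epsilon
               for a in actions)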
Code example #32
0
def main():

    # Setup environment.
    mdp_class, agent_type, samples = parse_args()
    is_goal_terminal = False
    mdp_distr = make_mdp_distr(mdp_class=mdp_class,
                               is_goal_terminal=is_goal_terminal)
    mdp_distr.set_gamma(0.99)
    actions = mdp_distr.get_actions()

    # Compute priors.

    # Stochastic mixture.
    mdp_distr_copy = copy.deepcopy(mdp_distr)
    opt_stoch_policy = ape.compute_optimal_stoch_policy(mdp_distr_copy)

    # Avg MDP
    avg_mdp = ape.compute_avg_mdp(mdp_distr)
    avg_mdp_vi = ValueIteration(avg_mdp,
                                delta=0.001,
                                max_iterations=1000,
                                sample_rate=5)
    iters, value = avg_mdp_vi.run_vi()

    # Make agents.

    # Q Learning
    ql_agent = QLearnerAgent(actions)
    shaped_ql_agent_prior = ShapedQAgent(shaping_policy=opt_stoch_policy,
                                         actions=actions,
                                         name="Prior-QLearning")
    shaped_ql_agent_avgmdp = ShapedQAgent(shaping_policy=avg_mdp_vi.policy,
                                          actions=actions,
                                          name="AvgMDP-QLearning")

    # RMax
    rmax_agent = RMaxAgent(actions)
    shaped_rmax_agent_prior = ShapedRMaxAgent(
        shaping_policy=opt_stoch_policy,
        state_space=avg_mdp_vi.get_states(),
        actions=actions,
        name="Prior-RMax")
    shaped_rmax_agent_avgmdp = ShapedRMaxAgent(
        shaping_policy=avg_mdp_vi.policy,
        state_space=avg_mdp_vi.get_states(),
        actions=actions,
        name="AvgMDP-RMax")
    prune_rmax_agent = PruneRMaxAgent(mdp_distr=mdp_distr)

    if agent_type == "rmax":
        agents = [
            rmax_agent, shaped_rmax_agent_prior, shaped_rmax_agent_avgmdp,
            prune_rmax_agent
        ]
    else:
        agents = [ql_agent, shaped_ql_agent_prior, shaped_ql_agent_avgmdp]

    # Run task.
    run_agents_multi_task(agents,
                          mdp_distr,
                          task_samples=samples,
                          episodes=1,
                          steps=200,
                          is_rec_disc_reward=False,
                          verbose=True)
Code example #33
0
File: sa_helpers.py Project: david-abel/simple_rl
def make_singletask_sa(mdp, indic_func, state_class, epsilon=0.0, aa_single_act=False, prob_of_mdp=1.0, track_act_opt_pr=False):
    '''
    Args:
        mdp (MDP)
        indic_func (S x S --> {0,1})
        state_class (Class)
        epsilon (float)

    Returns:
        (StateAbstraction)
    '''

    print("\tRunning VI...",)
    sys.stdout.flush()
    # Run VI
    if isinstance(mdp, MDPDistribution):
        mdp = mdp.sample()

    vi = ValueIteration(mdp)
    iters, val = vi.run_vi()
    print(" done.")

    print("\tMaking state abstraction...",)
    sys.stdout.flush()
    sa = StateAbstraction(phi={}, state_class=state_class, track_act_opt_pr=track_act_opt_pr)
    clusters = defaultdict(list)
    num_states = len(vi.get_states())

    actions = mdp.get_actions()
    # Find state pairs that satisfy the condition.
    for i, state_x in enumerate(vi.get_states()):
        sys.stdout.flush()
        clusters[state_x] = [state_x]

        for state_y in vi.get_states()[i:]:
            if not (state_x == state_y) and indic_func(state_x, state_y, vi, actions, epsilon=epsilon):
                clusters[state_x].append(state_y)
                clusters[state_y].append(state_x)

    print("making clusters...",)
    sys.stdout.flush()
    
    # Build SA.
    for i, state in enumerate(clusters.keys()):
        new_cluster = clusters[state]
        sa.make_cluster(new_cluster)

        # Destroy old so we don't double up.
        for s in clusters[state]:
            if s in clusters.keys():
                clusters.pop(s)
    
    if aa_single_act:
        # Put all optimal actions in a set associated with the ground state.
        for ground_s in sa.get_ground_states():
            a_star_set = set(vi.get_max_q_actions(ground_s))
            sa.set_actions_state_opt_dict(ground_s, a_star_set, prob_of_mdp)

    print(" done.")
    print("\tGround States:", num_states)
    print("\tAbstract:", sa.get_num_abstr_states())
    print()

    return sa
Code example #34
0
def main(open_plot=True):
    episodes = 100
    steps = 100
    gamma = 0.95

    mdp_class, is_goal_terminal, samples, alg = parse_args()

    # Setup multitask setting.
    mdp_distr = make_mdp_distr(mdp_class=mdp_class,
                               is_goal_terminal=is_goal_terminal,
                               gamma=gamma)
    actions = mdp_distr.get_actions()

    # Compute average MDP.
    print("Making and solving avg MDP...", end='')
    sys.stdout.flush()
    avg_mdp = compute_avg_mdp(mdp_distr)
    avg_mdp_vi = ValueIteration(avg_mdp,
                                delta=0.001,
                                max_iterations=1000,
                                sample_rate=5)
    iters, value = avg_mdp_vi.run_vi()

    ### Yuu

    # transfer_fixed_agent = FixedPolicyAgent(avg_mdp_vi.policy, name="transferFixed")
    rand_agent = RandomAgent(actions, name="$\\pi^u$")

    opt_q_func = compute_optimistic_q_function(mdp_distr)
    avg_q_func = get_q_func(avg_mdp_vi)

    best_v = -100  # Placeholder; after the scan below, this holds Vmax, the maximum value attainable in the environment.
    for x in opt_q_func:
        for y in opt_q_func[x]:
            best_v = max(best_v, opt_q_func[x][y])
    print("Vmax =", best_v)
    vmax = best_v

    vmax_func = defaultdict(lambda: defaultdict(lambda: vmax))

    if alg == "q":
        eps = 0.1
        lrate = 0.1
        pure_ql_agent = QLearningAgent(actions,
                                       gamma=gamma,
                                       alpha=lrate,
                                       epsilon=eps,
                                       name="Q-0")
        pure_ql_agent_opt = QLearningAgent(actions,
                                           gamma=gamma,
                                           alpha=lrate,
                                           epsilon=eps,
                                           default_q=vmax,
                                           name="Q-Vmax")
        ql_agent_upd_maxq = UpdatingQLearnerAgent(actions,
                                                  alpha=lrate,
                                                  epsilon=eps,
                                                  gamma=gamma,
                                                  default_q=vmax,
                                                  name="Q-MaxQInit")

        transfer_ql_agent_optq = QLearningAgent(actions,
                                                gamma=gamma,
                                                alpha=lrate,
                                                epsilon=eps,
                                                name="Q-UO")
        transfer_ql_agent_optq.set_init_q_function(opt_q_func)

        transfer_ql_agent_avgq = QLearningAgent(actions,
                                                gamma=gamma,
                                                alpha=lrate,
                                                epsilon=eps,
                                                name="Q-AverageQInit")
        transfer_ql_agent_avgq.set_init_q_function(avg_q_func)

        agents = [
            transfer_ql_agent_optq, ql_agent_upd_maxq, transfer_ql_agent_avgq,
            pure_ql_agent_opt, pure_ql_agent
        ]
    elif alg == "rmax":
        """
        Note that Rmax is a model-based algorithm and is very slow compared to model-free algorithms such as Q-learning and delayed Q-learning.
        """
        known_threshold = 10
        min_experience = 5
        pure_rmax_agent = RMaxAgent(actions,
                                    gamma=gamma,
                                    horizon=known_threshold,
                                    s_a_threshold=min_experience,
                                    name="RMAX-Vmax")
        updating_trans_rmax_agent = UpdatingRMaxAgent(
            actions,
            gamma=gamma,
            horizon=known_threshold,
            s_a_threshold=min_experience,
            name="RMAX-MaxQInit")
        trans_rmax_agent = RMaxAgent(actions,
                                     gamma=gamma,
                                     horizon=known_threshold,
                                     s_a_threshold=min_experience,
                                     name="RMAX-UO")
        trans_rmax_agent.set_init_q_function(opt_q_func)
        agents = [
            trans_rmax_agent, updating_trans_rmax_agent, pure_rmax_agent,
            rand_agent
        ]
    elif alg == "delayed-q":
        tolerance = 0.1
        min_experience = 5
        pure_delayed_ql_agent = DelayedQAgent(actions,
                                              gamma=gamma,
                                              m=min_experience,
                                              epsilon1=tolerance,
                                              name="DelayedQ-Vmax")
        pure_delayed_ql_agent.set_q_function(vmax_func)
        updating_delayed_ql_agent = UpdatingDelayedQLearningAgent(
            actions,
            default_q=vmax,
            gamma=gamma,
            m=min_experience,
            epsilon1=tolerance,
            name="DelayedQ-MaxQInit")
        updating_delayed_ql_agent.set_q_function(vmax_func)
        trans_delayed_ql_agent = DelayedQAgent(actions,
                                               gamma=gamma,
                                               m=min_experience,
                                               epsilon1=tolerance,
                                               name="DelayedQ-UO")
        trans_delayed_ql_agent.set_q_function(opt_q_func)

        agents = [
            pure_delayed_ql_agent, updating_delayed_ql_agent,
            trans_delayed_ql_agent, rand_agent
        ]
        # agents = [updating_delayed_ql_agent, trans_delayed_ql_agent, rand_agent]
    elif alg == "sample-effect":
        """
        This compares MaxQInit with different numbers of MDP samples used to compute the initial Q-function. Note that the performance on the sampled MDPs is ignored in this experiment. It reproduces Figure 4 of "Policy and Value Transfer for Lifelong Reinforcement Learning".
        """
        tolerance = 0.1
        min_experience = 5
        pure_delayed_ql_agent = DelayedQAgent(actions,
                                              opt_q_func,
                                              m=min_experience,
                                              epsilon1=tolerance,
                                              name="DelayedQ-Vmax")
        pure_delayed_ql_agent.set_vmax()
        dql_60samples = UpdatingDelayedQLearningAgent(
            actions,
            default_q=vmax,
            gamma=gamma,
            m=min_experience,
            epsilon1=tolerance,
            num_sample_tasks=60,
            name="$DelayedQ-MaxQInit60$")
        dql_40samples = UpdatingDelayedQLearningAgent(
            actions,
            default_q=vmax,
            gamma=gamma,
            m=min_experience,
            epsilon1=tolerance,
            num_sample_tasks=40,
            name="$DelayedQ-MaxQInit40$")
        dql_20samples = UpdatingDelayedQLearningAgent(
            actions,
            default_q=vmax,
            gamma=gamma,
            m=min_experience,
            epsilon1=tolerance,
            num_sample_tasks=20,
            name="$DelayedQ-MaxQInit20$")

        # Sample MDPs. Note that the performance of the sampled MDP is ignored and not included in the average in the final plot.
        run_agents_lifelong([dql_20samples],
                            mdp_distr,
                            samples=int(samples * 1 / 5.0),
                            episodes=episodes,
                            steps=steps,
                            reset_at_terminal=is_goal_terminal,
                            track_disc_reward=False,
                            cumulative_plot=True,
                            open_plot=open_plot)
        # mdp_distr.reset_tasks()
        run_agents_lifelong([dql_40samples],
                            mdp_distr,
                            samples=int(samples * 2 / 5.0),
                            episodes=episodes,
                            steps=steps,
                            reset_at_terminal=is_goal_terminal,
                            track_disc_reward=False,
                            cumulative_plot=True,
                            open_plot=open_plot)
        # mdp_distr.reset_tasks()
        run_agents_lifelong([dql_60samples],
                            mdp_distr,
                            samples=int(samples * 3 / 5.0),
                            episodes=episodes,
                            steps=steps,
                            reset_at_terminal=is_goal_terminal,
                            track_disc_reward=False,
                            cumulative_plot=True,
                            open_plot=open_plot)
        # mdp_distr.reset_tasks()
        # agents = [pure_delayed_ql_agent]
        agents = [
            dql_60samples, dql_40samples, dql_20samples, pure_delayed_ql_agent
        ]
    else:
        msg = "Unknown type of agent:" + alg + ". Use -agent_type (q, rmax, delayed-q)"
        assert False, msg

    # Run task.
    run_agents_lifelong(agents,
                        mdp_distr,
                        samples=samples,
                        episodes=episodes,
                        steps=steps,
                        reset_at_terminal=is_goal_terminal,
                        track_disc_reward=False,
                        cumulative_plot=True,
                        open_plot=open_plot)
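The loop over opt_q_func above recovers Vmax empirically from the optimistic Q-function. As a sanity check, the standard analytic bound for rewards in [0, Rmax] is Vmax = Rmax / (1 - gamma); a minimal sketch, where Rmax is an assumed bound on the reward rather than something read from mdp_distr:

def vmax_upper_bound(r_max, gamma):
    # Geometric-series bound on the discounted return: sum_t gamma^t * r_max.
    return r_max / (1.0 - gamma)

# For example, vmax_upper_bound(1.0, 0.95) == 20.0.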
Code example #35
0
class BeliefUpdater(object):
    ''' Wrapper class for different methods for belief state updates in POMDPs. '''
    def __init__(self,
                 mdp,
                 transition_func,
                 reward_func,
                 observation_func,
                 updater_type='discrete'):
        '''
        Args:
            mdp (POMDP)
            transition_func: T(s, a) --> s'
            reward_func: R(s, a) --> float
            observation_func: O(s, a) --> z
            updater_type (str)
        '''
        self.reward_func = reward_func
        self.updater_type = updater_type

        # We use the ValueIteration class to construct the transition and observation probabilities
        self.vi = ValueIteration(mdp, sample_rate=500)

        self.transition_probs = self.construct_transition_matrix(
            transition_func)
        self.observation_probs = self.construct_observation_matrix(
            observation_func, transition_func)

        if updater_type == 'discrete':
            self.updater = self.discrete_filter_updater
        elif updater_type == 'kalman':
            self.updater = self.kalman_filter_updater
        elif updater_type == 'particle':
            self.updater = self.particle_filter_updater
        else:
            raise AttributeError(
                'updater_type {} did not conform to expected type'.format(
                    updater_type))

    def discrete_filter_updater(self, belief, action, observation):
        def _compute_normalization_factor(bel):
            return sum(bel.values())

        def _update_belief_for_state(b, sp, T, O, a, z):
            return O[sp][z] * sum([T[s][a][sp] * b[s] for s in b])

        new_belief = defaultdict()
        for sprime in belief:
            new_belief[sprime] = _update_belief_for_state(
                belief, sprime, self.transition_probs, self.observation_probs,
                action, observation)

        normalization = _compute_normalization_factor(new_belief)

        for sprime in belief:
            if normalization > 0: new_belief[sprime] /= normalization

        return new_belief

    def kalman_filter_updater(self, belief, action, observation):
        pass

    def particle_filter_updater(self, belief, action, observation):
        pass

    def construct_transition_matrix(self, transition_func):
        '''
        Create an MLE of the transition probabilities by sampling from the transition_func
        multiple times.
        Args:
            transition_func: T(s, a) -> s'

        Returns:
            transition_probabilities (defaultdict): T(s, a, s') --> float
        '''
        self.vi._compute_matrix_from_trans_func()
        return self.vi.trans_dict

    def construct_observation_matrix(self, observation_func, transition_func):
        '''
        Create an MLE of the observation probabilities by sampling from the observation_func
        multiple times.
        Args:
            observation_func: O(s, a) -> z
            transition_func: T(s, a) -> s'

        Returns:
            observation_probabilities (defaultdict): O(s, z) --> float
        '''
        def normalize_probabilities(odict):
            norm_factor = sum(odict.values())
            for obs in odict:
                odict[obs] /= norm_factor
            return odict

        obs_dict = defaultdict(lambda: defaultdict(float))
        for state in self.vi.get_states():
            for action in self.vi.mdp.actions:
                for sample in range(self.vi.sample_rate):
                    observation = observation_func(state, action)
                    next_state = transition_func(state, action)
                    obs_dict[next_state][
                        observation] += 1. / self.vi.sample_rate
        for state in self.vi.get_states():
            obs_dict[state] = normalize_probabilities(obs_dict[state])
        return obs_dict
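discrete_filter_updater above is the standard discrete Bayes filter: b'(s') is proportional to O(z | s') times the sum over s of T(s' | s, a) * b(s), renormalized. A minimal usage sketch; the POMDP object and the names some_action / some_observation are placeholders, and it is assumed the POMDP exposes the transition_func, reward_func, and observation_func callbacks passed to the constructor above:

def one_belief_step(pomdp, some_action, some_observation):
    # Build the updater, start from a uniform prior, and take one filter step.
    bu = BeliefUpdater(pomdp, pomdp.transition_func, pomdp.reward_func,
                       pomdp.observation_func, updater_type='discrete')
    states = bu.vi.get_states()
    uniform_prior = {s: 1.0 / len(states) for s in states}
    return bu.updater(uniform_prior, some_action, some_observation)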
Code example #36
0
def make_singletask_sa(mdp,
                       indic_func,
                       state_class,
                       epsilon=0.0,
                       aa_single_act=False,
                       prob_of_mdp=1.0,
                       track_act_opt_pr=False):
    '''
    Args:
        mdp (MDP)
        indic_func (S x S --> {0,1})
        state_class (Class)
        epsilon (float)

    Returns:
        (StateAbstraction)
    '''

    print("\tRunning VI...", )
    sys.stdout.flush()
    # Run VI
    if isinstance(mdp, MDPDistribution):
        mdp = mdp.sample()

    vi = ValueIteration(mdp)
    iters, val = vi.run_vi()
    print(" done.")

    print("\tMaking state abstraction...", )
    sys.stdout.flush()
    sa = StateAbstraction(phi={},
                          state_class=state_class,
                          track_act_opt_pr=track_act_opt_pr)
    clusters = defaultdict(list)
    num_states = len(vi.get_states())

    actions = mdp.get_actions()
    # Find state pairs that satisfy the condition.
    for i, state_x in enumerate(vi.get_states()):
        sys.stdout.flush()
        clusters[state_x] = [state_x]

        for state_y in vi.get_states()[i:]:
            if not (state_x == state_y) and indic_func(
                    state_x, state_y, vi, actions, epsilon=epsilon):
                clusters[state_x].append(state_y)
                clusters[state_y].append(state_x)

    print("making clusters...", )
    sys.stdout.flush()

    # Build SA.
    for i, state in enumerate(clusters.keys()):
        new_cluster = clusters[state]
        sa.make_cluster(new_cluster)

        # Destroy old so we don't double up.
        for s in clusters[state]:
            if s in clusters.keys():
                clusters.pop(s)

    if aa_single_act:
        # Put all optimal actions in a set associated with the ground state.
        for ground_s in sa.get_ground_states():
            a_star_set = set(vi.get_max_q_actions(ground_s))
            sa.set_actions_state_opt_dict(ground_s, a_star_set, prob_of_mdp)

    print(" done.")
    print("\tGround States:", num_states)
    print("\tAbstract:", sa.get_num_abstr_states())
    print()

    return sa
Code example #37
0
class BeliefUpdater(object):
    ''' Wrapper class for different methods for belief state updates in POMDPs. '''

    def __init__(self, mdp, transition_func, reward_func, observation_func, updater_type='discrete'):
        '''
        Args:
            mdp (POMDP)
            transition_func: T(s, a) --> s'
            reward_func: R(s, a) --> float
            observation_func: O(s, a) --> z
            updater_type (str)
        '''
        self.reward_func = reward_func
        self.updater_type = updater_type

        # We use the ValueIteration class to construct the transition and observation probabilities
        self.vi = ValueIteration(mdp, sample_rate=500)

        self.transition_probs = self.construct_transition_matrix(transition_func)
        self.observation_probs = self.construct_observation_matrix(observation_func, transition_func)

        if updater_type == 'discrete':
            self.updater = self.discrete_filter_updater
        elif updater_type == 'kalman':
            self.updater = self.kalman_filter_updater
        elif updater_type == 'particle':
            self.updater = self.particle_filter_updater
        else:
            raise AttributeError('updater_type {} did not conform to expected type'.format(updater_type))

    def discrete_filter_updater(self, belief, action, observation):
        def _compute_normalization_factor(bel):
            return sum(bel.values())

        def _update_belief_for_state(b, sp, T, O, a, z):
            return O[sp][z] * sum([T[s][a][sp] * b[s] for s in b])

        new_belief = defaultdict()
        for sprime in belief:
            new_belief[sprime] = _update_belief_for_state(belief, sprime, self.transition_probs, self.observation_probs, action, observation)

        normalization = _compute_normalization_factor(new_belief)

        for sprime in belief:
            if normalization > 0: new_belief[sprime] /= normalization

        return new_belief

    def kalman_filter_updater(self, belief, action, observation):
        pass

    def particle_filter_updater(self, belief, action, observation):
        pass

    def construct_transition_matrix(self, transition_func):
        '''
        Create an MLE of the transition probabilities by sampling from the transition_func
        multiple times.
        Args:
            transition_func: T(s, a) -> s'

        Returns:
            transition_probabilities (defaultdict): T(s, a, s') --> float
        '''
        self.vi._compute_matrix_from_trans_func()
        return self.vi.trans_dict

    def construct_observation_matrix(self, observation_func, transition_func):
        '''
        Create an MLE of the observation probabilities by sampling from the observation_func
        multiple times.
        Args:
            observation_func: O(s, a) -> z
            transition_func: T(s, a) -> s'

        Returns:
            observation_probabilities (defaultdict): O(s, z) --> float
        '''
        def normalize_probabilities(odict):
            norm_factor = sum(odict.values())
            for obs in odict:
                odict[obs] /= norm_factor
            return odict

        obs_dict = defaultdict(lambda:defaultdict(float))
        for state in self.vi.get_states():
            for action in self.vi.mdp.actions:
                for sample in range(self.vi.sample_rate):
                    observation = observation_func(state, action)
                    next_state = transition_func(state, action)
                    obs_dict[next_state][observation] += 1. / self.vi.sample_rate
        for state in self.vi.get_states():
            obs_dict[state] = normalize_probabilities(obs_dict[state])
        return obs_dict
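Both construct_transition_matrix and construct_observation_matrix rely on the same estimator: sample a stochastic function repeatedly and normalize the empirical counts into a maximum-likelihood categorical distribution. A minimal standalone sketch of that estimator, independent of the POMDP machinery above (sample_fn is a placeholder for any zero-argument sampler returning hashable outcomes):

from collections import defaultdict

def empirical_distribution(sample_fn, num_samples=500):
    # MLE of Pr(outcome) from repeated calls to sample_fn().
    counts = defaultdict(float)
    for _ in range(num_samples):
        counts[sample_fn()] += 1.0
    total = sum(counts.values())
    return {outcome: count / total for outcome, count in counts.items()}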