Code Example #1
File: info_sa.py  Project: apragupta/IB_SA_simple_rl
def info_sa_visualize_abstr(mdp, demo_policy_lambda, beta=2.0, is_deterministic_ib=False, is_agent_in_control=False):
    '''
    Args:
        mdp (simple_rl.MDP)
        demo_policy_lambda (lambda : simple_rl.State --> str)
        beta (float)
        is_deterministic_ib (bool)
        is_agent_in_control (bool)

    Summary:
        Visualizes the state abstraction found by info_sa using pygame.
    '''
    if is_agent_in_control:
        # Run info_sa with the agent controlling the MDP.
        pmf_s_phi, phi_pmf, abstr_policy_pmf = agent_in_control.run_agent_in_control_info_sa(mdp, demo_policy_lambda, rounds=100, iters=500, beta=beta, is_deterministic_ib=is_deterministic_ib)
    else:
        # Run info_sa.
        pmf_s_phi, phi_pmf, abstr_policy_pmf = run_info_sa(mdp, demo_policy_lambda, iters=500, beta=beta, convergence_threshold=0.00001, is_deterministic_ib=is_deterministic_ib)

    lambda_abstr_policy = get_lambda_policy(abstr_policy_pmf)
    prob_s_phi = ProbStateAbstraction(phi_pmf)
    crisp_s_phi = convert_prob_sa_to_sa(prob_s_phi)

    vi = ValueIteration(mdp)
    print "\t|S|", vi.get_num_states()
    print "\t|S_\\phi|_crisp =", crisp_s_phi.get_num_abstr_states()

    from simple_rl.abstraction.state_abs.sa_helpers import visualize_state_abstr_grid
    visualize_state_abstr_grid(mdp, crisp_s_phi)
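
A minimal usage sketch, assuming info_sa.py from the project above is importable and that simple_rl's GridWorldMDP and ValueIteration supply the demonstrator policy; the grid layout and beta value are illustrative:

from simple_rl.tasks import GridWorldMDP
from simple_rl.planning import ValueIteration
from info_sa import info_sa_visualize_abstr  # assumes info_sa.py is on the path

# Build a small grid world and run value iteration to get a demonstrator policy.
mdp = GridWorldMDP(width=4, height=3, init_loc=(1, 1), goal_locs=[(4, 3)])
vi = ValueIteration(mdp)
vi.run_vi()

# Visualize the crisp state abstraction info_sa finds for this demonstrator.
info_sa_visualize_abstr(mdp, vi.policy, beta=2.0, is_deterministic_ib=True)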
Code Example #2
File: info_sa.py  Project: apragupta/IB_SA_simple_rl
def info_sa_compare_policies(mdp, demo_policy_lambda, beta=3.0, is_deterministic_ib=False, is_agent_in_control=False):
    '''
    Args:
        mdp (simple_rl.MDP)
        demo_policy_lambda (lambda : simple_rl.State --> str)
        beta (float)
        is_deterministic_ib (bool): If True, run DIB, else IB.
        is_agent_in_control (bool): If True, runs the DIB in agent_in_control.py instead.

    Summary:
        Runs info_sa and compares the value of the found policy with the demonstrator policy.
    '''
    if is_agent_in_control:
        # Run info_sa with the agent controlling the MDP.
        pmf_s_phi, phi_pmf, abstr_policy_pmf = agent_in_control.run_agent_in_control_info_sa(mdp, demo_policy_lambda, rounds=100, iters=500, beta=beta, is_deterministic_ib=is_deterministic_ib)
    else:
        # Run info_sa.
        pmf_s_phi, phi_pmf, abstr_policy_pmf = run_info_sa(mdp, demo_policy_lambda, iters=500, beta=beta, convergence_threshold=0.00001, is_deterministic_ib=is_deterministic_ib)

    # Make demonstrator agent and random agent.
    demo_agent = FixedPolicyAgent(demo_policy_lambda, name="$\\pi_d$")
    rand_agent = RandomAgent(mdp.get_actions(), name="$\\pi_u$")

    # Make abstract agent.
    lambda_abstr_policy = get_lambda_policy(abstr_policy_pmf)
    prob_s_phi = ProbStateAbstraction(phi_pmf)
    crisp_s_phi = convert_prob_sa_to_sa(prob_s_phi)
    abstr_agent = AbstractionWrapper(FixedPolicyAgent, state_abstr=crisp_s_phi, agent_params={"policy":lambda_abstr_policy, "name":"$\\pi_\\phi$"}, name_ext="")
    
    # Run.
    run_agents_on_mdp([demo_agent, abstr_agent, rand_agent], mdp, episodes=1, steps=1000)


    non_zero_abstr_states = [x for x in pmf_s_phi.values() if x > 0]
    # Print state space sizes.
    demo_vi = ValueIteration(mdp)
    print "\nState Spaces Sizes:"
    print "\t|S| =", demo_vi.get_num_states()
    print "\tH(S_\\phi) =", entropy(pmf_s_phi)
    print "\t|S_\\phi|_crisp =", crisp_s_phi.get_num_abstr_states()
    print "\tdelta_min =", min(non_zero_abstr_states)
    print "\tnum non zero states =", len(non_zero_abstr_states)
    print
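
A minimal usage sketch for the comparison function, under the same assumptions as above (simple_rl's GridWorldMDP and ValueIteration provide the demonstrator policy, and info_sa.py is importable); the grid and beta are illustrative:

from simple_rl.tasks import GridWorldMDP
from simple_rl.planning import ValueIteration
from info_sa import info_sa_compare_policies  # assumes info_sa.py is on the path

# Demonstrator policy from value iteration on a small grid world.
mdp = GridWorldMDP(width=4, height=3, init_loc=(1, 1), goal_locs=[(4, 3)])
vi = ValueIteration(mdp)
vi.run_vi()

# Compare the abstract policy found by info_sa against the demonstrator and a random agent.
info_sa_compare_policies(mdp, vi.policy, beta=3.0, is_deterministic_ib=False)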
Code Example #3
def make_info_sa_val_and_size_plots(mdp,
                                    demo_policy_lambda,
                                    beta_range,
                                    results_dir="info_sa_results",
                                    instances=3,
                                    include_stoch=False,
                                    is_agent_in_control=False):
    '''
    Args:
        mdp (simple_rl.MDP)
        demo_policy_lambda (lambda : simple_rl.State --> str)
        beta_range (list)
        results_dir (str)
        instances (int)
        include_stoch (bool): If True, also runs IB.
        is_agent_in_control (bool): If True, runs the agent_in_control.py variant of DIB-SA.

    Summary:
        Main plotting function for info_sa experiments.
    '''
    # Clear old results.
    all_policies = ["demo_val", "dibs_val", "dibs_states", "etad_states"]
    if include_stoch:
        all_policies += ["ib_val", "ib_states"]
    for policy in all_policies:
        if os.path.exists(os.path.join(results_dir, str(policy)) + ".csv"):
            os.remove(os.path.join(results_dir, str(policy)) + ".csv")

    # Set relevant params.
    param_dict = {
        "mdp": mdp,
        "iters": 500,
        "convergence_threshold": 0.0001,
        "demo_policy_lambda": demo_policy_lambda,
        "is_agent_in_control": is_agent_in_control
    }

    # Record value of demo policy and size of ground state space.
    demo_agent = FixedPolicyAgent(demo_policy_lambda)
    demo_val = evaluate_agent(demo_agent, mdp, instances=100)
    vi = ValueIteration(mdp)
    num_ground_states = vi.get_num_states()
    for beta in beta_range:
        write_datum_to_file(file_name="demo_val",
                            datum=demo_val,
                            extra_dir=results_dir)
        write_datum_to_file(file_name="ground_states",
                            datum=num_ground_states,
                            extra_dir=results_dir)

    # Run core algorithm for DIB and IB.
    for instance in range(instances):
        print "\nInstance", instance + 1, "of", str(instances) + "."
        random.jumpahead(1)

        # For each beta.
        for beta in beta_range:

            # Run DIB.
            dibs_val, dibs_states = _info_sa_val_and_size_plot_wrapper(
                beta=beta,
                param_dict=dict(param_dict.items() + {
                    "is_deterministic_ib": True,
                    "use_crisp_policy": False
                }.items()))
            write_datum_to_file(file_name="dibs_val",
                                datum=dibs_val,
                                extra_dir=results_dir)
            write_datum_to_file(file_name="dibs_states",
                                datum=dibs_states,
                                extra_dir=results_dir)

            if include_stoch:
                ib_val, ib_states = _info_sa_val_and_size_plot_wrapper(
                    beta=beta,
                    param_dict=dict(param_dict.items() + {
                        "is_deterministic_ib": False,
                        "use_crisp_policy": False
                    }.items()))
                write_datum_to_file(file_name="ib_val",
                                    datum=ib_val,
                                    extra_dir=results_dir)
                write_datum_to_file(file_name="ib_states",
                                    datum=ib_states,
                                    extra_dir=results_dir)

        # End instances.
        end_of_instance("dibs_val", extra_dir=results_dir)
        end_of_instance("dibs_states", extra_dir=results_dir)
        if include_stoch:
            end_of_instance("ib_val", extra_dir=results_dir)
            end_of_instance("ib_states", extra_dir=results_dir)

    beta_range_file = file(os.path.join(results_dir, "beta_range.csv"), "w")
    for beta in beta_range:
        beta_range_file.write(str(beta))
        beta_range_file.write(",")

    beta_range_file.close()

    make_beta_val_plot([p for p in all_policies if "val" in p],
                       results_dir,
                       is_agent_in_control=is_agent_in_control)
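
A minimal usage sketch for the plotting function, assuming it is importable from info_sa.py alongside the functions above; the beta sweep, instance count, and results directory are illustrative values rather than the project's actual experiment settings:

from simple_rl.tasks import GridWorldMDP
from simple_rl.planning import ValueIteration
from info_sa import make_info_sa_val_and_size_plots  # assumed to live in info_sa.py

# Demonstrator policy from value iteration on a small grid world.
mdp = GridWorldMDP(width=4, height=3, init_loc=(1, 1), goal_locs=[(4, 3)])
vi = ValueIteration(mdp)
vi.run_vi()

# Sweep beta and write value/size results for the DIB (and optionally IB) variants.
beta_range = [round(0.5 * i, 1) for i in range(11)]
make_info_sa_val_and_size_plots(mdp, vi.policy, beta_range,
                                results_dir="info_sa_results",
                                instances=3,
                                include_stoch=True)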