Example #1
    def get_optimal_policy(self, ent_wt=0.1, gamma=0.9):
        # Solve for the soft-optimal Q-function in closed form, then wrap it in
        # a softmax (Boltzmann) policy over the Q-values.
        q_fn = q_iteration.q_iteration(self, ent_wt=ent_wt, gamma=gamma, K=100)
        policy = CategoricalSoftQPolicy(
            env_spec=self.spec, ent_wt=ent_wt, hardcoded_q=q_fn)
        policy.set_q_func(q_fn)

        # Sanity check (kept as a comment): the policy's action probabilities are
        # a Boltzmann distribution over the Q-values at temperature ent_wt:
        # q_fn2 = 1.0/ent_wt * q_fn
        # probs = np.exp(q_fn2)/np.sum(np.exp(q_fn2), axis=1, keepdims=True)

        return policy
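
The commented-out check is the key identity behind these examples: a soft-Q policy is a Boltzmann distribution over the Q-values with temperature ent_wt. A minimal, self-contained sketch of that computation (the function name soft_q_policy_probs is illustrative, not part of the library):

import numpy as np

def soft_q_policy_probs(q_fn, ent_wt=0.1):
    """Boltzmann (soft-Q) action probabilities from a (num_states, num_actions) Q-table."""
    # Divide by the entropy weight (temperature) and subtract the per-state max
    # before exponentiating, for numerical stability.
    logits = q_fn / ent_wt
    logits = logits - np.max(logits, axis=1, keepdims=True)
    probs = np.exp(logits)
    return probs / np.sum(probs, axis=1, keepdims=True)  # rows sum to 1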
Example #2
def tabular_maxent_irl(env,
                       demo_visitations,
                       num_itrs=50,
                       ent_wt=1.0,
                       lr=1e-3,
                       state_only=False,
                       discount=0.99,
                       T=5):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env,
                            reward_matrix=reward_fn,
                            ent_wt=ent_wt,
                            warmstart_q=q_rew,
                            K=q_itrs,
                            gamma=discount)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env,
                                             q_rew,
                                             ent_wt=ent_wt,
                                             T=T,
                                             discount=discount)

        # MaxEnt IRL gradient: demo visitations minus policy visitations; the
        # sign is flipped so the Adam update below takes a descent step.
        grad = -(demo_visitations - pol_visitations)
        it.record('VisitationInfNorm', np.max(np.abs(grad)))
        if state_only:
            # collapse to a state-only reward by summing the gradient over actions
            grad = np.sum(grad, axis=1, keepdims=True)
        reward_fn = update(reward_fn, grad)

        if it.heartbeat:
            print(it.itr_message())
            print('\t', it.pop_mean('VisitationInfNorm'))
    return reward_fn, q_rew
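
The adam_optimizer helper is not shown in these snippets; all the loop above needs is a stateful update(params, grad) callable that takes a descent step on the negated MaxEnt gradient. A minimal sketch under that assumption (the hyperparameter defaults are the standard Adam values, not taken from the source):

import numpy as np

def adam_optimizer(lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return a stateful update(params, grad) closure that applies Adam descent steps."""
    state = {'m': None, 'v': None, 't': 0}

    def update(params, grad):
        if state['m'] is None:
            state['m'] = np.zeros_like(params, dtype=float)
            state['v'] = np.zeros_like(params, dtype=float)
        state['t'] += 1
        # Exponential moving averages of the gradient and its square.
        state['m'] = beta1 * state['m'] + (1 - beta1) * grad
        state['v'] = beta2 * state['v'] + (1 - beta2) * grad ** 2
        # Bias-corrected estimates.
        m_hat = state['m'] / (1 - beta1 ** state['t'])
        v_hat = state['v'] / (1 - beta2 ** state['t'])
        return params - lr * m_hat / (np.sqrt(v_hat) + eps)

    return update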
Example #3


import numpy as np
from inverse_rl.envs.tabular.q_iteration import q_iteration
from inverse_rl.envs.tabular.simple_env import random_env


if __name__ == "__main__":
    # test IRL; compute_visitation, get_policy and tabular_maxent_irl are
    # assumed to be defined earlier in the same module
    np.set_printoptions(suppress=True)
    env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8)
    env2 = random_env(16, 4, seed=2, terminate=False, t_sparsity=0.8)
    dS = env.spec.observation_space.flat_dim
    dU = env.spec.action_space.flat_dim
    dO = 8
    ent_wt = 0.5
    discount = 0.7
    obs_matrix = np.random.randn(dS, dO)
    true_q = q_iteration(env, K=150, ent_wt=ent_wt, gamma=discount)
    true_sa_visits = compute_visitation(env,
                                        true_q,
                                        ent_wt=ent_wt,
                                        T=5,
                                        discount=discount)
    expert_pol = get_policy(true_q, ent_wt=ent_wt)

    if True:
        learned_rew, learned_q = tabular_maxent_irl(env,
                                                    true_sa_visits,
                                                    lr=0.01,
                                                    num_itrs=1000,
                                                    ent_wt=ent_wt,
                                                    state_only=True,
                                                    discount=discount)
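
q_iteration (imported in Example #5 from inverse_rl.envs.tabular.q_iteration) is used throughout as a closed-form soft planner. For reference, here is a minimal sketch of entropy-regularized Q-iteration on a tabular MDP; the explicit transitions and rewards arrays stand in for whatever the env object exposes, so this signature is an assumption rather than the library's API:

import numpy as np
from scipy.special import logsumexp

def soft_q_iteration(transitions, rewards, ent_wt=1.0, gamma=0.99, K=100, warmstart_q=None):
    """Soft (entropy-regularized) Q-iteration.

    transitions: (dS, dA, dS) array with P(s' | s, a)
    rewards:     (dS, dA) array with r(s, a)
    """
    dS, dA, _ = transitions.shape
    q = np.zeros((dS, dA)) if warmstart_q is None else warmstart_q.copy()
    for _ in range(K):
        # Soft value function: V(s) = ent_wt * logsumexp(Q(s, .) / ent_wt)
        v = ent_wt * logsumexp(q / ent_wt, axis=1)
        # Soft Bellman backup: Q(s, a) = r(s, a) + gamma * E_{s'}[V(s')]
        q = rewards + gamma * transitions.dot(v)
    return q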
Example #4
def tabular_gcl_irl(env,
                    demo_visitations,
                    irl_model,
                    num_itrs=50,
                    ent_wt=1.0,
                    lr=1e-3,
                    state_only=False,
                    discount=0.99,
                    batch_size=20024):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    # Enumerate every (state, action) pair as one-hot vectors so the learned
    # reward can later be read back over the full (s, a) table.
    states_all = []
    actions_all = []
    for s in range(dim_obs):
        for a in range(dim_act):
            states_all.append(flat_to_one_hot(s, dim_obs))
            actions_all.append(flat_to_one_hot(a, dim_act))
    states_all = np.array(states_all)
    actions_all = np.array(actions_all)
    path_all = {'observations': states_all, 'actions': actions_all}

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env,
                            reward_matrix=reward_fn,
                            ent_wt=ent_wt,
                            warmstart_q=q_rew,
                            K=q_itrs,
                            gamma=discount)
        pol_rew = get_policy(q_rew, ent_wt=ent_wt)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env,
                                             q_rew,
                                             ent_wt=ent_wt,
                                             T=5,
                                             discount=discount)

        # now we need to sample states and actions, and give them to the discriminator
        demo_path = sample_states(env, q_rew, demo_visitations, batch_size,
                                  ent_wt)
        irl_model.set_demos([demo_path])
        path = sample_states(env, q_rew, pol_visitations, batch_size, ent_wt)
        irl_model.fit([path],
                      policy=pol_rew,
                      max_itrs=200,
                      lr=1e-3,
                      batch_size=1024)

        # Read out the discriminator's reward for every (s, a) pair and unflatten
        # it into a (dim_obs, dim_act) table, in the same s-major order as path_all.
        rew_stack = irl_model.eval([path_all])[0]
        reward_fn = np.zeros_like(q_rew)
        i = 0
        for s in range(dim_obs):
            for a in range(dim_act):
                reward_fn[s, a] = rew_stack[i]
                i += 1

        diff_visit = np.abs(demo_visitations - pol_visitations)
        it.record('VisitationDiffInfNorm', np.max(diff_visit))
        it.record('VisitationDiffAvg', np.mean(diff_visit))

        if it.heartbeat:
            print(it.itr_message())
            print('\tVisitationDiffInfNorm:',
                  it.pop_mean('VisitationDiffInfNorm'))
            print('\tVisitationDiffAvg:', it.pop_mean('VisitationDiffAvg'))

            print('visitations', pol_visitations)
            print('diff_visit', diff_visit)
            # For readability, shift the learned reward so its mean matches the
            # true reward's mean (a constant shift does not change the optimal policy).
            adjusted_rew = reward_fn - np.mean(reward_fn) + np.mean(
                env.rew_matrix)
            print('adjusted_rew', adjusted_rew)
    return reward_fn, q_rew
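
flat_to_one_hot is only used above to enumerate every (state, action) pair as one-hot vectors for the discriminator. A minimal sketch consistent with that usage (the real helper may differ in edge-case handling):

import numpy as np

def flat_to_one_hot(idx, dim):
    """Return a length-dim one-hot vector with a 1 at position idx."""
    vec = np.zeros(dim)
    vec[idx] = 1.0
    return vec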
Example #5
if __name__ == "__main__":
    # test IRL
    import numpy as np
    from inverse_rl.envs.tabular.q_iteration import q_iteration
    from inverse_rl.envs.tabular.simple_env import random_env
    from inverse_rl.utils.plotter import TabularPlotter
    np.set_printoptions(suppress=True)
    env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8)
    env2 = random_env(16, 4, seed=2, terminate=False, t_sparsity=0.8)
    #plotter = TabularPlotter(4, 16, invert_y=True, text_values=False)
    dS = env.spec.observation_space.flat_dim
    dU = env.spec.action_space.flat_dim
    dO = 8
    ent_wt = 0.5
    discount = 0.7
    obs_matrix = np.random.randn(dS, dO)
    true_q = q_iteration(env, K=150, ent_wt=ent_wt, gamma=discount)
    true_sa_visits = compute_visitation(env,
                                        true_q,
                                        ent_wt=ent_wt,
                                        T=5,
                                        discount=discount)
    expert_pol = get_policy(true_q, ent_wt=ent_wt)

    if True:
        learned_rew, learned_q = tabular_maxent_irl(env,
                                                    true_sa_visits,
                                                    lr=0.01,
                                                    num_itrs=1000,
                                                    ent_wt=ent_wt,
                                                    state_only=False,
                                                    discount=discount)
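
compute_visitation is the other helper these scripts rely on: the discounted state-action visitation of the soft-Q policy over a horizon of T steps. A sketch under stated assumptions: the transition matrix and initial-state distribution are passed in explicitly because the env attributes they would come from are not shown, and whether the library normalizes the result is likewise not visible in these snippets.

import numpy as np

def soft_policy_from_q(q_fn, ent_wt):
    """Boltzmann policy pi(a | s) derived from a soft Q-function."""
    logits = q_fn / ent_wt
    logits = logits - np.max(logits, axis=1, keepdims=True)
    probs = np.exp(logits)
    return probs / np.sum(probs, axis=1, keepdims=True)

def compute_visitation_sketch(transitions, init_dist, q_fn, ent_wt=1.0, T=5, discount=0.99):
    """Discounted (state, action) visitation over T steps, normalized to sum to 1.

    transitions: (dS, dA, dS) array with P(s' | s, a)
    init_dist:   (dS,) initial state distribution
    """
    pol = soft_policy_from_q(q_fn, ent_wt)         # (dS, dA)
    state_dist = init_dist.astype(float).copy()    # rho_0
    sa_visit = np.zeros_like(q_fn, dtype=float)
    norm = 0.0
    for t in range(T):
        sa_t = state_dist[:, None] * pol           # occupancy of (s, a) at step t
        sa_visit += (discount ** t) * sa_t
        norm += discount ** t
        # rho_{t+1}(s') = sum_{s, a} rho_t(s) * pi(a|s) * P(s'|s, a)
        state_dist = np.einsum('sa,sap->p', sa_t, transitions)
    return sa_visit / norm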