Example #1
import numpy as np

def sample_states(env, q_fn, visitation_probs, n_sample, ent_wt):
    """Sample (s, a) pairs from a state-action visitation distribution and
    return one-hot observations and actions, action log-probabilities under
    the soft-optimal policy, and sampled next-state observations.
    Relies on get_policy and flat_to_one_hot defined elsewhere in the module."""
    dS, dA = visitation_probs.shape
    samples = np.random.choice(np.arange(dS * dA),
                               size=n_sample,
                               p=visitation_probs.reshape(dS * dA))
    policy = get_policy(q_fn, ent_wt=ent_wt)
    observations = samples // dA
    actions = samples % dA
    a_logprobs = np.log(policy[observations, actions])

    observations_next = []
    for i in range(n_sample):
        t_distr = env.tabular_trans_distr(observations[i], actions[i])
        next_state = flat_to_one_hot(np.random.choice(np.arange(len(t_distr)),
                                                      p=t_distr),
                                     ndim=dS)
        observations_next.append(next_state)
    observations_next = np.array(observations_next)

    return {
        'observations': flat_to_one_hot(observations, ndim=dS),
        'actions': flat_to_one_hot(actions, ndim=dA),
        'a_logprobs': a_logprobs,
        'observations_next': observations_next
    }
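
The snippet above relies on a flat_to_one_hot helper that is not shown. Below is a minimal sketch of what such a helper might look like, assuming it maps a scalar index to a one-hot vector of length ndim and a 1-D index array to a matrix of one-hot rows; the project's actual helper may differ.

import numpy as np

def flat_to_one_hot(idx, ndim):
    # Scalar index -> one-hot vector of length ndim.
    if np.isscalar(idx) or np.ndim(idx) == 0:
        one_hot = np.zeros(ndim)
        one_hot[int(idx)] = 1.0
        return one_hot
    # 1-D array of indices -> matrix with one one-hot row per index.
    idx = np.asarray(idx, dtype=int)
    one_hot = np.zeros((len(idx), ndim))
    one_hot[np.arange(len(idx)), idx] = 1.0
    return one_hot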
Example #2
import numpy as np

def compute_visitation(env, q_fn, ent_wt=1.0, T=50, discount=1.0):
    """Compute the average state-action visitation frequencies of the
    soft-optimal policy for q_fn over a T-step rollout from the initial
    state distribution."""
    pol_probs = get_policy(q_fn, ent_wt=ent_wt)

    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim
    state_visitation = np.expand_dims(env.initial_state_distribution, axis=1)
    t_matrix = env.transition_matrix  # S x A x S
    sa_visit_t = np.zeros((dim_obs, dim_act, T))

    for i in range(T):
        sa_visit = state_visitation * pol_probs
        sa_visit_t[:, :, i] = sa_visit  # undiscounted; a discounted variant would store (discount**i) * sa_visit
        # marginalize out (s, a) against the transition matrix to get the next-state visitation
        new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix)
        state_visitation = np.expand_dims(new_state_visitation, axis=1)
    return np.sum(sa_visit_t, axis=2) / float(T)
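
Both examples call get_policy, which is assumed here to return the Boltzmann (MaxEnt) policy implied by a tabular Q-function with entropy weight ent_wt, i.e. pi(a|s) proportional to exp(Q(s, a) / ent_wt). A minimal sketch under that assumption:

import numpy as np
from scipy.special import logsumexp

def get_policy(q_fn, ent_wt=1.0):
    # Soft value: V(s) = ent_wt * logsumexp(Q(s, .) / ent_wt)
    v = ent_wt * logsumexp(q_fn / ent_wt, axis=1, keepdims=True)
    # pi(a|s) = exp((Q(s, a) - V(s)) / ent_wt); each row sums to 1.
    return np.exp((q_fn - v) / ent_wt)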
Example #3
import numpy as np

def tabular_gcl_irl(env,
                    demo_visitations,
                    irl_model,
                    num_itrs=50,
                    ent_wt=1.0,
                    lr=1e-3,
                    state_only=False,
                    discount=0.99,
                    batch_size=20024):
    """Guided cost learning (GCL) style tabular IRL: alternately compute the
    soft-optimal policy for the current reward estimate, then fit the
    discriminator-based reward model (irl_model) against the demonstration
    visitations."""
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    states_all = []
    actions_all = []
    for s in range(dim_obs):
        for a in range(dim_act):
            states_all.append(flat_to_one_hot(s, dim_obs))
            actions_all.append(flat_to_one_hot(a, dim_act))
    states_all = np.array(states_all)
    actions_all = np.array(actions_all)
    path_all = {'observations': states_all, 'actions': actions_all}

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env,
                            reward_matrix=reward_fn,
                            ent_wt=ent_wt,
                            warmstart_q=q_rew,
                            K=q_itrs,
                            gamma=discount)
        pol_rew = get_policy(q_rew, ent_wt=ent_wt)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env,
                                             q_rew,
                                             ent_wt=ent_wt,
                                             T=5,
                                             discount=discount)

        # now we need to sample states and actions, and give them to the discriminator
        demo_path = sample_states(env, q_rew, demo_visitations, batch_size,
                                  ent_wt)
        irl_model.set_demos([demo_path])
        path = sample_states(env, q_rew, pol_visitations, batch_size, ent_wt)
        irl_model.fit([path],
                      policy=pol_rew,
                      max_itrs=200,
                      lr=1e-3,
                      batch_size=1024)

        rew_stack = irl_model.eval([path_all])[0]
        reward_fn = np.zeros_like(q_rew)
        i = 0
        for s in range(dim_obs):
            for a in range(dim_act):
                reward_fn[s, a] = rew_stack[i]
                i += 1

        diff_visit = np.abs(demo_visitations - pol_visitations)
        it.record('VisitationDiffInfNorm', np.max(diff_visit))
        it.record('VisitationDiffAvg', np.mean(diff_visit))

        if it.heartbeat:
            print(it.itr_message())
            print('\tVisitationDiffInfNorm:',
                  it.pop_mean('VisitationDiffInfNorm'))
            print('\tVisitationDiffAvg:', it.pop_mean('VisitationDiffAvg'))

            print('visitations', pol_visitations)
            print('diff_visit', diff_visit)
            adjusted_rew = reward_fn - np.mean(reward_fn) + np.mean(
                env.rew_matrix)
            print('adjusted_rew', adjusted_rew)
    return reward_fn, q_rew
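
q_iteration is assumed here to perform entropy-regularized (soft) value iteration on the tabular MDP, using env.transition_matrix (S x A x S) as in Example #2; defaulting the reward to env.rew_matrix when no reward_matrix is passed is an assumption. A minimal sketch, not the project's actual implementation:

import numpy as np
from scipy.special import logsumexp

def q_iteration(env, reward_matrix=None, ent_wt=1.0, warmstart_q=None, K=50,
                gamma=0.99):
    # Soft Bellman backup:
    #   V(s)    = ent_wt * logsumexp(Q(s, .) / ent_wt)
    #   Q(s, a) = r(s, a) + gamma * sum_s' T(s, a, s') * V(s')
    t_matrix = env.transition_matrix                      # S x A x S
    rew = env.rew_matrix if reward_matrix is None else reward_matrix
    q = np.zeros_like(rew) if warmstart_q is None else warmstart_q.copy()
    for _ in range(K):
        v = ent_wt * logsumexp(q / ent_wt, axis=1)        # shape (S,)
        q = rew + gamma * t_matrix.dot(v)                 # shape (S, A)
    return q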
Example #4
    env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8)
    env2 = random_env(16, 4, seed=2, terminate=False, t_sparsity=0.8)
    #plotter = TabularPlotter(4, 16, invert_y=True, text_values=False)
    dS = env.spec.observation_space.flat_dim
    dU = env.spec.action_space.flat_dim
    dO = 8
    ent_wt = 0.5
    discount = 0.7
    obs_matrix = np.random.randn(dS, dO)
    true_q = q_iteration(env, K=150, ent_wt=ent_wt, gamma=discount)
    true_sa_visits = compute_visitation(env,
                                        true_q,
                                        ent_wt=ent_wt,
                                        T=5,
                                        discount=discount)
    expert_pol = get_policy(true_q, ent_wt=ent_wt)

    if True:
        learned_rew, learned_q = tabular_maxent_irl(env,
                                                    true_sa_visits,
                                                    lr=0.01,
                                                    num_itrs=1000,
                                                    ent_wt=ent_wt,
                                                    state_only=False,
                                                    discount=discount)
        #extracted_rew = get_reward(env, learned_q, ent_wt=ent_wt, gamma=discount)
        #new_q = q_iteration(env, K=150, ent_wt=ent_wt, gamma=discount, reward_matrix=extracted_rew)
        learned_pol = get_policy(learned_q, ent_wt=ent_wt)

    else:
        import tensorflow as tf