def tabular_maxent_irl(env, demo_visitations, num_itrs=50, ent_wt=1.0, lr=1e-3,
                       state_only=False, discount=0.99, T=5):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env, reward_matrix=reward_fn, ent_wt=ent_wt,
                            warmstart_q=q_rew, K=q_itrs, gamma=discount)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env, q_rew, ent_wt=ent_wt, T=T,
                                             discount=discount)

        grad = -(demo_visitations - pol_visitations)
        it.record('VisitationInfNorm', np.max(np.abs(grad)))
        if state_only:
            grad = np.sum(grad, axis=1, keepdims=True)
        reward_fn = update(reward_fn, grad)

        if it.heartbeat:
            print(it.itr_message())
            print('\t', it.pop_mean('VisitationInfNorm'))
    return reward_fn, q_rew
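
# --- Hedged usage sketch (not part of the original module) -------------------
# A hypothetical example of how tabular_maxent_irl might be driven. It assumes
# a tabular env compatible with the helpers above (observation_space.flat_dim,
# action_space.flat_dim, env.rew_matrix) and an expert soft Q-table q_expert,
# e.g. obtained by running q_iteration against the env's true reward matrix.
# The helper name _example_maxent_irl is illustrative only.
def _example_maxent_irl(env, q_expert, ent_wt=1.0, discount=0.99):
    # Expert state-action visitation frequencies play the role of demonstrations.
    demo_visitations = compute_visitation(env, q_expert, ent_wt=ent_wt, T=5,
                                          discount=discount)
    # Recover a reward table whose soft-optimal policy matches those visitations.
    learned_rew, learned_q = tabular_maxent_irl(env, demo_visitations,
                                                num_itrs=100, ent_wt=ent_wt,
                                                lr=1e-2, discount=discount)
    return learned_rew, learned_q
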
def tabular_gcl_irl(env, demo_visitations, irl_model, num_itrs=50, ent_wt=1.0,
                    lr=1e-3, state_only=False, discount=0.99, batch_size=20024):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    states_all = []
    actions_all = []
    for s in range(dim_obs):
        for a in range(dim_act):
            states_all.append(flat_to_one_hot(s, dim_obs))
            actions_all.append(flat_to_one_hot(a, dim_act))
    states_all = np.array(states_all)
    actions_all = np.array(actions_all)
    path_all = {'observations': states_all, 'actions': actions_all}

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env, reward_matrix=reward_fn, ent_wt=ent_wt,
                            warmstart_q=q_rew, K=q_itrs, gamma=discount)
        pol_rew = get_policy(q_rew, ent_wt=ent_wt)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env, q_rew, ent_wt=ent_wt, T=5,
                                             discount=discount)

        # now we need to sample states and actions, and give them to the discriminator
        demo_path = sample_states(env, q_rew, demo_visitations, batch_size, ent_wt)
        irl_model.set_demos([demo_path])
        path = sample_states(env, q_rew, pol_visitations, batch_size, ent_wt)
        irl_model.fit([path], policy=pol_rew, max_itrs=200, lr=1e-3, batch_size=1024)

        rew_stack = irl_model.eval([path_all])[0]
        reward_fn = np.zeros_like(q_rew)
        i = 0
        for s in range(dim_obs):
            for a in range(dim_act):
                reward_fn[s, a] = rew_stack[i]
                i += 1

        diff_visit = np.abs(demo_visitations - pol_visitations)
        it.record('VisitationDiffInfNorm', np.max(diff_visit))
        it.record('VisitationDiffAvg', np.mean(diff_visit))

        if it.heartbeat:
            print(it.itr_message())
            print('\tVisitationDiffInfNorm:', it.pop_mean('VisitationDiffInfNorm'))
            print('\tVisitationDiffAvg:', it.pop_mean('VisitationDiffAvg'))
            print('visitations', pol_visitations)
            print('diff_visit', diff_visit)
            adjusted_rew = reward_fn - np.mean(reward_fn) + np.mean(env.rew_matrix)
            print('adjusted_rew', adjusted_rew)
    return reward_fn, q_rew
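
# --- Hedged usage sketch (not part of the original module) -------------------
# tabular_gcl_irl additionally needs a discriminator ("irl_model") exposing
# set_demos / fit / eval as used above; the concrete class depends on the
# surrounding repo, so this hypothetical helper simply takes one as an argument.
# Unlike tabular_maxent_irl, which matches visitation counts directly, the GCL
# variant reads the reward back out of the fitted discriminator each iteration.
def _example_gcl_irl(env, irl_model, q_expert, ent_wt=1.0, discount=0.99):
    # Expert state-action visitation frequencies again serve as demonstrations.
    demo_visitations = compute_visitation(env, q_expert, ent_wt=ent_wt, T=5,
                                          discount=discount)
    learned_rew, learned_q = tabular_gcl_irl(env, demo_visitations, irl_model,
                                             num_itrs=50, ent_wt=ent_wt,
                                             discount=discount)
    return learned_rew, learned_q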