def test_policy_om_reasonable_mdp():
    """Sanity-check the soft-optimal policy and occupancy measures on
    the hand-constructed ReasonableMDP."""
    mdp = ReasonableMDP()
    # Soft value function V, soft Q-values, and the MaxEnt policy pi,
    # plus per-timestep (Dt) and total (D) occupancy measures.
    V, Q, pi = mce_partition_fh(mdp)
    Dt, D = mce_occupancy_measures(mdp, pi=pi)
    for quantity in (V, Q, pi, Dt, D):
        assert np.all(np.isfinite(quantity))
    # Actions 0 & 2 (which go to states 1 & 2) should be taken with
    # roughly equal probability...
    assert np.allclose(pi[:19, 0, 0], pi[:19, 0, 2])
    # ...and both should be strongly preferred to action 1, which goes to
    # the poorly-rewarded state 3.
    assert np.all(pi[:19, 0, 0] > 2 * pi[:19, 0, 1])
    # States 3 & 4 should have roughly uniform policies.
    # NOTE(review): this check slices timesteps [:5] while every other
    # check uses [:19] -- confirm the shorter range is intentional.
    uniform_block = pi[:5, 3:5]
    assert np.allclose(uniform_block, np.ones_like(uniform_block) / 3.0)
    # States 1 & 2 are symmetric, so their policies should match.
    assert np.allclose(pi[:19, 1, :], pi[:19, 2, :])
    # In state 1: action 2 reaches state 4 with certainty, action 0 only
    # with some probability, and action 1 always goes to the bad state,
    # so the preference ordering should be 2 > 0 > 1.
    assert np.all(pi[:19, 1, 2] > pi[:19, 1, 0])
    assert np.all(pi[:19, 1, 0] > pi[:19, 1, 1])
    # The t=0 occupancy must equal the MDP's initial state distribution.
    assert np.allclose(Dt[0], mdp.initial_state_dist)
def test_mce_irl_reasonable_mdp(model_class, model_kwargs):
    """Run MCE IRL on ReasonableMDP and check that it recovers the
    demonstrator's occupancy measure with sane reward weights.

    Args:
        model_class: reward model class to instantiate; called as
            ``model_class(mdp.obs_dim, seed=13, **model_kwargs)``.
        model_kwargs: extra keyword arguments forwarded to ``model_class``.
    """
    mdp = ReasonableMDP()
    # Demonstrator occupancy measure from the soft-optimal policy.
    # Only the policy and the total occupancy D are needed, so the
    # value/Q-function and per-timestep occupancy are discarded.
    _, _, pi = mce_partition_fh(mdp)
    _, D = mce_occupancy_measures(mdp, pi=pi)
    rmodel = model_class(mdp.obs_dim, seed=13, **model_kwargs)
    opt = jaxopt.adam(1e-2)
    final_weights, final_counts = mce_irl(mdp, opt, rmodel, D, linf_eps=1e-3)
    # IRL should approximately match the demonstrator's occupancy measure.
    assert np.allclose(final_counts, D, atol=1e-3, rtol=1e-3)
    # Make sure weights have non-insane norm.
    assert np.linalg.norm(final_weights) < 1000
def test_policy_om_random_mdp():
    """Test that optimal policy occupancy measure ("om") for a random MDP
    makes sense."""
    mdp = RandomMDP(
        n_states=16,
        n_actions=3,
        branch_factor=2,
        horizon=20,
        random_obs=True,
        obs_dim=5,
        generator_seed=42,
    )
    V, Q, pi = mce_partition_fh(mdp)
    assert np.isfinite(V).all()
    assert np.isfinite(Q).all()
    assert np.isfinite(pi).all()
    # The policy must be a valid probability distribution along the
    # (last) action axis: non-negative and summing to one.
    assert (pi >= 0).all()
    assert np.allclose(pi.sum(axis=-1), 1)
    Dt, D = mce_occupancy_measures(mdp, pi=pi)
    assert np.isfinite(D).all()
    assert (D > 0).any()
    # Expected number of state visits (over all states) should be equal
    # to the horizon.
    assert np.allclose(D.sum(), mdp.horizon)