Code example #1
0
def test_policy_om_reasonable_mdp():
    """Sanity-check the optimal policy and occupancy measures on ReasonableMDP.

    Verifies finiteness of all outputs, the expected preference ordering of
    actions implied by the MDP's rewards/transitions, and that the first
    occupancy-measure slice matches the initial state distribution.
    """
    mdp = ReasonableMDP()
    value_fn, q_fn, policy = mce_partition_fh(mdp)
    occ_t, occ = mce_occupancy_measures(mdp, pi=policy)

    # None of the solver outputs should contain NaN/inf.
    for arr in (value_fn, q_fn, policy, occ_t, occ):
        assert np.all(np.isfinite(arr))

    # In state 0, actions 0 and 2 (leading to states 1 and 2) should be
    # roughly interchangeable...
    assert np.allclose(policy[:19, 0, 0], policy[:19, 0, 2])
    # ...and both strongly preferred over action 1, which leads to the
    # poorly-rewarded state 3.
    assert np.all(policy[:19, 0, 0] > 2 * policy[:19, 0, 1])

    # States 3 and 4 should have near-uniform policies.
    # NOTE(review): the time slice here is `:5` while every other check uses
    # `:19` — looks inconsistent, but preserved as-is; confirm intent.
    uniform_slice = policy[:5, 3:5]
    assert np.allclose(uniform_slice, np.ones_like(uniform_slice) / 3.0)

    # States 1 and 2 are symmetric, so their policies should match.
    assert np.allclose(policy[:19, 1, :], policy[:19, 2, :])

    # In state 1: action 2 (certain transition to state 4) beats action 0
    # (probabilistic transition there), which in turn beats action 1
    # (always goes to the bad state).
    assert np.all(policy[:19, 1, 2] > policy[:19, 1, 0])
    assert np.all(policy[:19, 1, 0] > policy[:19, 1, 1])

    # The t=0 occupancy measure must equal the initial state distribution.
    assert np.allclose(occ_t[0], mdp.initial_state_dist)
Code example #2
0
def test_mce_irl_reasonable_mdp(model_class, model_kwargs):
    """Run MCE IRL on ReasonableMDP and check it recovers the demo occupancy.

    The demonstrator occupancy measure comes from the exact optimal policy;
    MCE IRL should reproduce it closely without blowing up the reward weights.
    """
    mdp = ReasonableMDP()

    # Build the "demonstration": occupancy measure of the optimal policy.
    _, _, opt_policy = mce_partition_fh(mdp)
    _, demo_occ = mce_occupancy_measures(mdp, pi=opt_policy)

    reward_model = model_class(mdp.obs_dim, seed=13, **model_kwargs)
    optimizer = jaxopt.adam(1e-2)
    final_weights, final_counts = mce_irl(
        mdp, optimizer, reward_model, demo_occ, linf_eps=1e-3)

    # Learned occupancy should match the demonstrations...
    assert np.allclose(final_counts, demo_occ, atol=1e-3, rtol=1e-3)
    # ...with reward weights of sane magnitude.
    assert np.linalg.norm(final_weights) < 1000
Code example #3
0
def test_policy_om_random_mdp():
    """Check that the optimal-policy occupancy measure of a random MDP is sane.

    The policy must be a valid finite probability distribution over actions,
    and the state occupancy measure must be finite, non-trivial, and sum to
    the horizon (expected total number of state visits).
    """
    mdp = RandomMDP(
        n_states=16,
        n_actions=3,
        branch_factor=2,
        horizon=20,
        random_obs=True,
        obs_dim=5,
        generator_seed=42,
    )
    value_fn, q_fn, policy = mce_partition_fh(mdp)
    for arr in (value_fn, q_fn, policy):
        assert np.all(np.isfinite(arr))
    # The policy must be a probability distribution along the action axis.
    assert np.all(policy >= 0)
    assert np.allclose(np.sum(policy, axis=-1), 1)

    _, occ = mce_occupancy_measures(mdp, pi=policy)
    assert np.all(np.isfinite(occ))
    assert np.any(occ > 0)
    # Total expected state visits over the episode equal the horizon length.
    assert np.allclose(np.sum(occ), mdp.horizon)