Example #1
0
def test_policy_om_reasonable_mdp():
  """Sanity-check soft-optimal policy & occupancy measures on ReasonableMDP."""
  mdp = ReasonableMDP()
  # soft value function, Q-values and policy for the MDP
  value_fn, q_fn, policy = mce_partition_fh(mdp)
  occ_per_t, occ = mce_occupancy_measures(mdp, pi=policy)
  # nothing coming back should contain NaN or inf
  for arr in (value_fn, q_fn, policy, occ_per_t, occ):
    assert np.all(np.isfinite(arr))
  # actions 0 & 2 lead to states 1 & 2 respectively, which should be valued
  # roughly equally, so their probabilities should match
  assert np.allclose(policy[:19, 0, 0], policy[:19, 0, 2])
  # action 1 leads to state 3, which has poor reward, so actions 0 & 2 should
  # be strongly preferred over it
  assert np.all(policy[:19, 0, 0] > 2 * policy[:19, 0, 1])
  # states 3 & 4 should produce a near-uniform distribution over the 3 actions
  uniform_part = policy[:5, 3:5]
  assert np.allclose(uniform_part, np.ones_like(uniform_part) / 3.0)
  # states 1 & 2 are symmetric, so their policies should agree
  assert np.allclose(policy[:19, 1, :], policy[:19, 2, :])
  # in state 1: action 2 reaches state 4 with certainty and should beat
  # action 0 (which only gets there with some probability); both should beat
  # action 1 (which always transitions to the bad state)
  assert np.all(policy[:19, 1, 2] > policy[:19, 1, 0])
  assert np.all(policy[:19, 1, 0] > policy[:19, 1, 1])
  # the time-0 occupancy must coincide with the initial state distribution
  assert np.allclose(occ_per_t[0], mdp.initial_state_dist)
Example #2
0
def test_mce_irl_reasonable_mdp(model_class, model_kwargs):
  """Run MCE IRL on ReasonableMDP and check it matches the demo occupancy."""
  mdp = ReasonableMDP()

  # demonstration occupancy measure from the soft-optimal policy
  _, _, expert_pi = mce_partition_fh(mdp)
  _, expert_occ = mce_occupancy_measures(mdp, pi=expert_pi)

  reward_model = model_class(mdp.obs_dim, seed=13, **model_kwargs)
  optimiser = jaxopt.adam(1e-2)
  weights, visit_counts = mce_irl(mdp, optimiser, reward_model, expert_occ, linf_eps=1e-3)

  # IRL should reproduce the demonstration occupancy measure closely
  assert np.allclose(visit_counts, expert_occ, atol=1e-3, rtol=1e-3)
  # reward weights should not have blown up during optimisation
  assert np.linalg.norm(weights) < 1000
Example #3
0
def test_policy_om_random_mdp():
    """Occupancy measure ("om") of the optimal policy for a random MDP is sane."""
    env = gym.make("imitation/Random-v0")
    value_fn, q_fn, policy = mce_partition_fh(env)
    for arr in (value_fn, q_fn, policy):
        assert np.all(np.isfinite(arr))
    # the policy must be a valid probability distribution over the last axis
    assert np.all(policy >= 0)
    assert np.allclose(np.sum(policy, axis=-1), 1)

    occ_per_t, occ = mce_occupancy_measures(env, pi=policy)
    assert np.all(np.isfinite(occ))
    assert np.any(occ > 0)
    # summing expected state visits over all states should give exactly the
    # horizon length
    assert np.allclose(np.sum(occ), env.horizon)