Example #1
0
def test_script_recurrent(env: Env, policy: Policy):
    """Check that the scripted policy yields the same actions as the eager policy.

    The eager policy is stepped with an explicitly threaded hidden state, whereas
    the scripted module is only given the observation. Both are compared on two
    consecutive steps, and once more after the hidden state is re-initialized and
    `reset()` is called on the scripted module.

    :param env: environment fixture (not used directly in this test)
    :param policy: recurrent policy under test
    """
    # Call double() on the policy and create its scripted counterpart
    script_policy = policy.double().script()

    def draw_obs():
        # Fresh uniformly sampled observation from the policy's spec, as a tensor
        return to.from_numpy(policy.env_spec.obs_space.sample_uniform())

    # Two consecutive steps, carrying the eager policy's hidden state along by hand
    hid = policy.init_hidden()
    for _ in range(2):
        observation = draw_obs()
        act_eager, hid = policy(observation, hid)
        to.testing.assert_allclose(act_eager, script_policy(observation))

    # Start over on both sides and compare one more step
    hid = policy.init_hidden()
    script_policy.reset()

    observation = draw_obs()
    act_eager, _ = policy(observation, hid)
    to.testing.assert_allclose(act_eager, script_policy(observation))
Example #2
0
def test_recurrent_policy_batching(env: Env, policy: Policy, batch_size: int):
    """Check that a recurrent policy treats the rows of a batch independently.

    The policy is evaluated once on a full batch of observations and hidden states,
    and once on the first half of that batch. The actions and new hidden states of
    the subset must match the corresponding rows of the full-batch result.

    :param env: environment fixture (not used directly in this test)
    :param policy: recurrent policy under test
    :param batch_size: number of observations stacked into one batch
    """
    assert policy.is_recurrent
    obs = np.stack([
        policy.env_spec.obs_space.sample_uniform() for _ in range(batch_size)
    ])  # one sampled observation per row, i.e. batch_size rows
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden(batch_size)
    # Use a random one to ensure we don't just run into the 0-special-case.
    # rand_like draws from [0, 1); an argument-less in-place random_() on a float
    # tensor would instead produce integers up to 2^mantissa, i.e. huge values.
    hidden = to.rand_like(hidden)
    assert hidden.shape == (batch_size, policy.hidden_size)

    two_headed = isinstance(policy, TwoHeadedRNNPolicyBase)

    # Two-headed policies return an additional head output between action and hidden state
    if two_headed:
        act, _, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert hid_new.shape == (batch_size, policy.hidden_size)

    if batch_size > 1:
        # Try to use a subset of the batch (the first half of the rows)
        subset = to.arange(batch_size // 2)
        if two_headed:
            act_sub, _, hid_sub = policy(obs[subset, :], hidden[subset, :])
        else:
            act_sub, hid_sub = policy(obs[subset, :], hidden[subset, :])
        to.testing.assert_allclose(act_sub, act[subset, :])
        to.testing.assert_allclose(hid_sub, hid_new[subset, :])
Example #3
0
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    """Check a single forward step of a recurrent policy for conformity and reproducibility.

    The policy is called twice with the same observation and the same (random)
    hidden state. The returned hidden state must have `policy.hidden_size` entries,
    and both calls must yield the same action, the same second head output (for
    two-headed policies), and the same new hidden state.

    :param env: environment fixture (not used directly in this test)
    :param policy: recurrent policy under test
    """
    assert policy.is_recurrent
    obs = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden()
    # Use a random one to ensure we don't just run into the 0-special-case
    hidden = to.rand_like(hidden)
    assert len(hidden) == policy.hidden_size

    two_headed = isinstance(policy, TwoHeadedRNNPolicyBase)

    # Test general conformity
    if two_headed:
        act, otherhead, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert len(hid_new) == policy.hidden_size

    # Test reproducibility: same inputs must give the same outputs
    if two_headed:
        act2, otherhead2, hid_new2 = policy(obs, hidden)
    else:
        act2, hid_new2 = policy(obs, hidden)
    to.testing.assert_allclose(act, act2)
    if two_headed:
        to.testing.assert_allclose(otherhead, otherhead2)
    # Compare the hidden states of the first and the second call with each other
    # (comparing hid_new2 with itself would pass trivially)
    to.testing.assert_allclose(hid_new, hid_new2)
Example #4
0
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    """Run one forward step of a recurrent policy and check the output types.

    An initial hidden state and a uniformly sampled observation from the
    environment's observation space are passed to the policy. The action, the
    new hidden state, and (for two-headed policies) the second head output must
    all be tensors.

    :param env: environment providing the observation space to sample from
    :param policy: recurrent policy under test
    """
    hidden = policy.init_hidden()
    obs = to.from_numpy(env.obs_space.sample_uniform()).to(dtype=to.get_default_dtype())

    if not isinstance(policy, TwoHeadedRNNPolicyBase):
        act, hidden = policy(obs, hidden)
    else:
        # Two-headed policies return an extra output between action and hidden state
        act, head2, hidden = policy(obs, hidden)
        assert isinstance(head2, to.Tensor)

    assert all(isinstance(out, to.Tensor) for out in (act, hidden))