def test_script_recurrent(env: Env, policy: Policy):
    """Check that the scripted version of a recurrent policy matches the eager policy.

    The eager policy is called with an explicitly threaded hidden state, while the
    scripted module is called with the observation only. Two consecutive steps are
    compared, then both are reset and one more step is compared.
    """
    # NOTE(review): assuming Policy follows nn.Module.double() semantics, this casts the
    # policy itself to float64, so the eager calls below also run in double precision
    scripted = policy.double().script()

    def step_and_compare(hid):
        # Draw a fresh observation, run it through both versions and compare the actions
        observation = to.from_numpy(policy.env_spec.obs_space.sample_uniform())
        act_eager, hid_next = policy(observation, hid)
        act_scripted = scripted(observation)
        to.testing.assert_allclose(act_eager, act_scripted)
        return hid_next

    # Two consecutive steps, tracing the eager hidden state manually
    hidden = policy.init_hidden()
    for _ in range(2):
        hidden = step_and_compare(hidden)

    # Reset both versions and make sure they agree again on the first step
    hidden = policy.init_hidden()
    scripted.reset()
    step_and_compare(hidden)
def test_recurrent_policy_batching(env: Env, policy: Policy, batch_size: int):
    """Check that a recurrent policy processes a batch row by row.

    A batch of uniformly sampled observations and a random hidden state are fed through
    the policy. When the batch holds more than one element, the first half of the batch
    is evaluated on its own. Every output must match the corresponding rows of the
    full-batch result.
    """
    assert policy.is_recurrent

    def _forward(o, h):
        # Two-headed policies return (act, other_head, hidden); single-headed ones return
        # (act, hidden). Normalize to a 3-tuple with None for the missing head.
        if isinstance(policy, TwoHeadedRNNPolicyBase):
            act_, other_, hid_ = policy(o, h)
            return act_, other_, hid_
        act_, hid_ = policy(o, h)
        return act_, None, hid_

    # shape = (batch_size, obs_dim), where obs_dim depends on the environment
    obs = np.stack([policy.env_spec.obs_space.sample_uniform() for _ in range(batch_size)])
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create the initial hidden state, then replace it by a random one so we don't just run
    # into the 0-special-case. rand_like samples from [0, 1). The former in-place random_()
    # filled the float tensor with integers up to 2**mantissa, which saturates the
    # recurrent nonlinearities and weakens the comparison below.
    hidden = to.rand_like(policy.init_hidden(batch_size))
    assert hidden.shape == (batch_size, policy.hidden_size)

    act, other, hid_new = _forward(obs, hidden)
    assert hid_new.shape == (batch_size, policy.hidden_size)

    if batch_size > 1:
        # Try to use a subset of the batch; its outputs must equal the matching full-batch rows
        subset = to.arange(batch_size // 2)
        act_sub, other_sub, hid_sub = _forward(obs[subset, :], hidden[subset, :])
        to.testing.assert_allclose(act_sub, act[subset, :])
        to.testing.assert_allclose(hid_sub, hid_new[subset, :])
        if other is not None:
            # NOTE(review): assumes the second head is batched along dim 0 like the action
            to.testing.assert_allclose(other_sub, other[subset])
def test_recurrent_policy_one_step_reproducibility(env: Env, policy: Policy):
    """Check the output conformity and reproducibility of one recurrent policy step.

    A random hidden state and one sampled observation are fed through the policy twice.
    The new hidden state must keep the policy's hidden size, and the second call must
    reproduce every output of the first (action, the second head if present, and the
    hidden state).

    Renamed from ``test_recurrent_policy_one_step``: another function in this module
    reuses that name and rebinds it, so this test was shadowed and never collected by
    pytest.
    """
    assert policy.is_recurrent
    obs = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create the initial hidden state, then use a random one so we don't just run into
    # the 0-special-case
    hidden = policy.init_hidden()
    hidden = to.rand_like(hidden)
    assert len(hidden) == policy.hidden_size

    two_headed = isinstance(policy, TwoHeadedRNNPolicyBase)

    # Test general conformity
    if two_headed:
        act, otherhead, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert len(hid_new) == policy.hidden_size

    # Test reproducibility: the same inputs must yield the same outputs
    if two_headed:
        act2, otherhead2, hid_new2 = policy(obs, hidden)
        to.testing.assert_allclose(act, act2)
        to.testing.assert_allclose(otherhead, otherhead2)
    else:
        act2, hid_new2 = policy(obs, hidden)
        to.testing.assert_allclose(act, act2)
    # Previously compared hid_new2 with itself, which could never fail
    to.testing.assert_allclose(hid_new, hid_new2)
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    """Smoke-test a single forward step of a recurrent policy.

    Feeds the policy's initial hidden state and one uniformly sampled observation from
    the environment through the policy. Checks that every returned output is a torch
    tensor.
    """
    hidden_init = policy.init_hidden()
    observation = to.from_numpy(env.obs_space.sample_uniform()).to(dtype=to.get_default_dtype())

    outputs = policy(observation, hidden_init)
    if isinstance(policy, TwoHeadedRNNPolicyBase):
        # Two-headed policies additionally return the output of their second head
        action, second_head, hidden_next = outputs
        assert isinstance(second_head, to.Tensor)
    else:
        action, hidden_next = outputs

    assert isinstance(action, to.Tensor)
    assert isinstance(hidden_next, to.Tensor)