def test_recurrent_policy_batching(env: Env, policy: Policy, batch_size: int):
    """
    Check that a recurrent policy treats the samples of a batch independently.

    A batch of uniformly sampled observations is passed through the policy together with a random hidden state, and
    the shape of the returned hidden state is checked. If the batch has more than one element, the first half of the
    batch is passed through the policy on its own; its actions and hidden states must match the corresponding rows of
    the full-batch result.

    :param env: not used in the body of this test
    :param policy: recurrent policy under test
    :param batch_size: number of observations in the batch
    """
    assert policy.is_recurrent

    # Assuming sample_uniform() returns a flat vector, the stacked batch has shape (batch_size, obs_dim)
    obs = np.stack([policy.env_spec.obs_space.sample_uniform() for _ in range(batch_size)])
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden(batch_size)
    # Use a random one to ensure we don't just run into the 0-special-case. Sample from [0, 1) as the one-step test
    # does: Tensor.random_() fills a floating-point tensor with integers up to 2^mantissa, which can saturate bounded
    # activations and make the subset comparison below far less sensitive. rand_like also leaves the tensor returned
    # by init_hidden() untouched instead of overwriting it in place.
    hidden = to.rand_like(hidden)
    assert hidden.shape == (batch_size, policy.hidden_size)

    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act, _, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert hid_new.shape == (batch_size, policy.hidden_size)

    if batch_size > 1:
        # Try to use a subset of the batch (its first half); the results must equal the matching full-batch rows
        subset = to.arange(batch_size // 2)
        if isinstance(policy, TwoHeadedRNNPolicyBase):
            act_sub, _, hid_sub = policy(obs[subset, :], hidden[subset, :])
        else:
            act_sub, hid_sub = policy(obs[subset, :], hidden[subset, :])
        to.testing.assert_allclose(act_sub, act[subset, :])
        to.testing.assert_allclose(hid_sub, hid_new[subset, :])
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    """
    Check that a recurrent policy performs a well-formed and reproducible step on a single (unbatched) observation.

    A uniformly sampled observation and a random hidden state are passed through the policy twice. The new hidden
    state must have `policy.hidden_size` elements, and both passes must return the same outputs.

    :param env: not used in the body of this test
    :param policy: recurrent policy under test
    """
    assert policy.is_recurrent
    obs = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden()
    # Use a random one to ensure we don't just run into the 0-special-case
    hidden = to.rand_like(hidden)
    assert len(hidden) == policy.hidden_size

    two_headed = isinstance(policy, TwoHeadedRNNPolicyBase)

    # Test general conformity
    if two_headed:
        act, otherhead, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert len(hid_new) == policy.hidden_size

    # Test reproducibility: a second pass with identical inputs must yield identical outputs
    if two_headed:
        act2, otherhead2, hid_new2 = policy(obs, hidden)
    else:
        act2, hid_new2 = policy(obs, hidden)
    to.testing.assert_allclose(act, act2)
    if two_headed:
        to.testing.assert_allclose(otherhead, otherhead2)
    # Compare with the first pass (comparing hid_new2 against itself would always pass)
    to.testing.assert_allclose(hid_new, hid_new2)
def evaluate(policy: Policy, inps: to.Tensor, targs: to.Tensor, windowed: bool, cascaded: bool,
             num_init_samples: int, hidden: Optional[to.Tensor] = None, loss_fcn=nn.MSELoss(), verbose: bool = True):
    """
    Run a policy step by step over a time series and compute the prediction loss.

    :param policy: policy (PyTorch module) used for the predictions
    :param inps: input samples, indexed along the first dimension
    :param targs: target samples, must have the same number of rows as `inps`
    :param windowed: forwarded to every call of `TSPred.predict`
    :param cascaded: if `True`, every prediction after the first one is computed from the previous prediction instead
                     of from the next input sample
    :param num_init_samples: number of leading samples that are only used to initialize the hidden state; their
                             targets are excluded from the predictions and the loss
    :param hidden: initial hidden state, passed to the first call of `TSPred.predict`
    :param loss_fcn: loss function, called as `loss_fcn(predictions, targets)` following PyTorch's convention
    :param verbose: if `True`, print a summary of the evaluation
    :return: predictions for all samples after the initialization phase, and the loss over these predictions
    :raises pyrado.ShapeErr: if `inps` and `targs` differ in their first dimension
    """
    if not inps.shape[0] == targs.shape[0]:
        raise pyrado.ShapeErr(given=inps, expected_match=targs)

    # Set policy, i.e. PyTorch nn.Module, to evaluation mode
    policy.eval()
    try:
        # The targets of the initialization phase are not predicted
        targs = targs[num_init_samples:, :] if num_init_samples > 0 else targs
        preds = to.empty_like(targs)

        # Pass the first samples through the network in order to initialize the hidden state.
        # NOTE(review): with num_init_samples == 0, inps[0] is fed here and again at idx == 0 below — confirm intended.
        inp = inps[:num_init_samples, :] if num_init_samples > 0 else inps[0].unsqueeze(0)  # running input
        pred, hidden = TSPred.predict(policy, inp, windowed, cascaded=False, hidden=hidden)

        # Run steps consecutively, reusing the hidden state. Cascading is handled here by choosing the next input,
        # hence TSPred.predict is always called with cascaded=False.
        for idx in range(inps.shape[0] - num_init_samples):
            if not cascaded or idx == 0:
                # Feed the next ground-truth input sample
                inp = inps[idx + num_init_samples, :].unsqueeze(0)
            else:
                # Feed back the latest prediction (presumes the output has the same dimension as the input)
                inp = pred
            pred, hidden = TSPred.predict(policy, inp, windowed, cascaded=False, hidden=hidden)
            preds[idx, :] = pred

        # Compute loss for the entire data set at once; PyTorch losses take (input, target)
        loss = loss_fcn(preds, targs)
        if verbose:
            print_cbt(
                f'The {policy.name} policy with {policy.num_param} parameters predicted {preds.shape[0]} data points '
                f'with a loss of {loss.item():.4e}.', 'g')
    finally:
        # Set policy, i.e. PyTorch nn.Module, back to training mode (also if an error occurred above)
        policy.train()

    return preds, loss