Example #1
def test_recurrent_policy_batching(env: Env, policy: Policy, batch_size: int):
    assert policy.is_recurrent
    obs = np.stack([
        policy.env_spec.obs_space.sample_uniform() for _ in range(batch_size)
    ])  # shape = (batch_size, obs dim)
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden(batch_size)
    # Use a random one to ensure we don't just run into the 0-special-case
    hidden.random_()
    assert hidden.shape == (batch_size, policy.hidden_size)

    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act, _, hid_new = policy(obs, hidden)
    else:
        act, hid_new = policy(obs, hidden)
    assert hid_new.shape == (batch_size, policy.hidden_size)

    if batch_size > 1:
        # Try to use a subset of the batch
        subset = to.arange(batch_size // 2)
        if isinstance(policy, TwoHeadedRNNPolicyBase):
            act_sub, _, hid_sub = policy(obs[subset, :], hidden[subset, :])
        else:
            act_sub, hid_sub = policy(obs[subset, :], hidden[subset, :])
        to.testing.assert_allclose(act_sub, act[subset, :])
        to.testing.assert_allclose(hid_sub, hid_new[subset, :])
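
The property this test asserts is that a recurrent module applied to a batch returns, for any subset of that batch, exactly the corresponding rows of the full-batch result. The sketch below illustrates this with a plain nn.GRUCell rather than a Pyrado policy; obs_dim, hidden_size, and batch_size are arbitrary illustrative values.

import torch as to
import torch.nn as nn

# Minimal sketch of the batching property checked above, using a plain
# nn.GRUCell instead of a Pyrado policy; the sizes are arbitrary.
obs_dim, hidden_size, batch_size = 4, 8, 6
cell = nn.GRUCell(obs_dim, hidden_size).eval()

obs = to.rand(batch_size, obs_dim)
hidden = to.rand(batch_size, hidden_size)

with to.no_grad():
    hid_new = cell(obs, hidden)  # full batch
    subset = to.arange(batch_size // 2)
    hid_sub = cell(obs[subset, :], hidden[subset, :])  # subset of the batch

# The subset result must match the corresponding rows of the full-batch result
to.testing.assert_allclose(hid_sub, hid_new[subset, :])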
Example #2
def test_recurrent_policy_one_step(env: Env, policy: Policy):
    assert policy.is_recurrent
    obs = policy.env_spec.obs_space.sample_uniform()
    obs = to.from_numpy(obs).to(dtype=to.get_default_dtype())

    # Do this in evaluation mode to disable dropout & co
    policy.eval()

    # Create initial hidden state
    hidden = policy.init_hidden()
    # Use a random one to ensure we don't just run into the 0-special-case
    hidden = to.rand_like(hidden)
    assert len(hidden) == policy.hidden_size

    # Test general conformity
    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act, otherhead, hid_new = policy(obs, hidden)
        assert len(hid_new) == policy.hidden_size
    else:
        act, hid_new = policy(obs, hidden)
        assert len(hid_new) == policy.hidden_size

    # Test reproducibility
    if isinstance(policy, TwoHeadedRNNPolicyBase):
        act2, otherhead2, hid_new2 = policy(obs, hidden)
        to.testing.assert_allclose(act, act2)
        to.testing.assert_allclose(otherhead, otherhead2)
        to.testing.assert_allclose(hid_new, hid_new2)
    else:
        act2, hid_new2 = policy(obs, hidden)
        to.testing.assert_allclose(act, act2)
        to.testing.assert_allclose(hid_new, hid_new2)
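
The reason for calling policy.eval() before the reproducibility check can be shown with a small, self-contained sketch (plain PyTorch, not Pyrado; the layer sizes are arbitrary): with dropout active in training mode, two forward passes on the same input generally differ, while in evaluation mode they are identical.

import torch as to
import torch.nn as nn

# Illustrative sketch: eval() disables dropout, which is what makes the
# reproducibility check above meaningful. Layer sizes are arbitrary.
net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.ReLU(), nn.Linear(8, 2))
obs = to.rand(4)

net.train()
out_a = net(obs)
out_b = net(obs)  # generally differs from out_a because dropout is active

net.eval()
with to.no_grad():
    out1 = net(obs)
    out2 = net(obs)
to.testing.assert_allclose(out1, out2)  # identical in evaluation mode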
Example #3
    def evaluate(policy: Policy,
                 inps: to.Tensor,
                 targs: to.Tensor,
                 windowed: bool,
                 cascaded: bool,
                 num_init_samples: int,
                 hidden: Optional[to.Tensor] = None,
                 loss_fcn=nn.MSELoss(),
                 verbose: bool = True):
        if inps.shape[0] != targs.shape[0]:
            raise pyrado.ShapeErr(given=inps, expected_match=targs)

        # Set the policy, i.e. the PyTorch nn.Module, to evaluation mode
        policy.eval()

        targs = targs[num_init_samples:, :] if num_init_samples > 0 else targs
        preds = to.empty_like(targs)

        # Pass the first samples through the network in order to initialize the hidden state
        inp = inps[:num_init_samples, :] if num_init_samples > 0 else inps[0].unsqueeze(0)  # running input
        pred, hidden = TSPred.predict(policy,
                                      inp,
                                      windowed,
                                      cascaded=False,
                                      hidden=hidden)

        # Run steps consecutively reusing the hidden state
        for idx in range(inps.shape[0] - num_init_samples):
            if not cascaded or idx == 0:
                # Use the latest ground-truth sample as the next input
                inp = inps[idx + num_init_samples, :].unsqueeze(0)
            else:
                # Use the latest prediction as the next input
                inp = pred

            pred, hidden = TSPred.predict(policy,
                                          inp,
                                          windowed,
                                          cascaded=False,
                                          hidden=hidden)
            preds[idx, :] = pred

        # Compute loss for the entire data set at once
        loss = loss_fcn(targs, preds)

        if verbose:
            print_cbt(
                f'The {policy.name} policy with {policy.num_param} parameters predicted {inps.shape[0]} data points '
                f'with a loss of {loss.item():.4e}.', 'g')

        # Set policy, i.e. PyTorch nn.Module, back to training mode
        policy.train()

        return preds, loss
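
Because TSPred.predict and the Pyrado policy classes are not shown here, the loop structure of evaluate() can be sketched with a plain PyTorch one-step predictor. Everything below (class name, dimensions, the sine-wave data) is an illustrative assumption, not Pyrado code: with cascaded=False the ground-truth sample is fed in at every step, whereas cascaded=True feeds the previous prediction back in.

import torch as to
import torch.nn as nn

# Sketch of the prediction loop that evaluate() walks through, with a plain
# GRU-based one-step predictor standing in for the Pyrado policy and
# TSPred.predict. All names and dimensions here are illustrative assumptions.
class OneStepPredictor(nn.Module):
    def __init__(self, dim_data: int = 1, hidden_size: int = 16):
        super().__init__()
        self.cell = nn.GRUCell(dim_data, hidden_size)
        self.head = nn.Linear(hidden_size, dim_data)

    def forward(self, inp: to.Tensor, hidden: to.Tensor):
        hidden = self.cell(inp, hidden)
        return self.head(hidden), hidden

predictor = OneStepPredictor().eval()
data = to.sin(to.linspace(0, 6.28, 50)).unsqueeze(1)  # (num_samples, dim_data)
hidden = to.zeros(1, 16)
cascaded = False

preds = to.empty_like(data[1:])
inp = data[0].unsqueeze(0)  # running input, analogous to the init samples above
with to.no_grad():
    for idx in range(data.shape[0] - 1):
        pred, hidden = predictor(inp, hidden)  # reuse the hidden state
        preds[idx] = pred.squeeze(0)
        if cascaded:
            inp = pred  # feed the latest prediction back in
        else:
            inp = data[idx + 1].unsqueeze(0)  # feed in the ground-truth sample

loss = nn.MSELoss()(preds, data[1:])  # loss over the whole sequence at once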