Ejemplo n.º 1
0
def test_WithReturns_bootstrap():
    """Check returns for a trajectory that contains no terminal step.

    Every transition is built with ``done=False`` and ``critic_value=2``, and
    the transform runs with ``discount=1.0`` and ``norm_returns=False``. The
    test expects the last return to equal 2 and the one before it to equal
    ``2 + 2``.
    """
    field_names = ("reward", "obs", "done", "critic_value")
    Transition = attr.make_class("Transition", field_names)
    # Rewards run 0..3; no step is flagged as done.
    trajectory = [
        Transition(reward=step, obs=Obs(), done=False, critic_value=2)
        for step in range(4)
    ]
    transform = T.WithReturns(discount=1.0, norm_returns=False)

    result = transform(trajectory)
    assert len(result) == len(trajectory)

    last, second_to_last = result[-1], result[-2]
    # NOTE(review): going by the test name, the final return is presumably
    # bootstrapped from the critic value (2) -- confirm against WithReturns.
    assert last.retrn == 2
    assert approx_eq(second_to_last.retrn, 2 + 2)
Ejemplo n.º 2
0
def test_WithReturns_Normalize():
    """Check that the transform calls ``normalize`` and ``denormalize``.

    The transform's normalizer is swapped for a ``MagicMock`` that wraps the
    real one, so calls are recorded on the mock. ``normalize`` must have been
    called after processing a trajectory whose last step is terminal, and
    ``denormalize`` must have been called after then processing a trajectory
    with no terminal step.
    """
    field_names = ("reward", "obs", "done", "critic_value")
    Transition = attr.make_class("Transition", field_names)
    transform = T.WithReturns(norm_returns=True)
    spy = mock.MagicMock(wraps=transform.normalizer)
    transform.normalizer = spy

    # First trajectory: only the final step (index 3) is marked done.
    finished = [
        Transition(reward=step, obs=Obs(), done=(step == 3), critic_value=2)
        for step in range(4)
    ]
    transform(finished)
    assert transform.normalizer.normalize.called

    # Second trajectory: no step is marked done.
    unfinished = [
        Transition(reward=step, obs=Obs(), done=False, critic_value=2)
        for step in range(4)
    ]
    transform(unfinished)
    assert transform.normalizer.denormalize.called
Ejemplo n.º 3
0
def test_WithReturns():
    """Check the shape and types of the output of a normalizing WithReturns.

    A four-step trajectory whose last step is terminal is transformed with
    ``discount=0.1`` and ``norm_returns=True``. The result must be a list of
    the same length whose items each carry a float ``retrn`` and an ``Obs``
    in ``obs``, with the returns summing to (approximately) zero. The
    normalizer's running mean stored under the ``None`` key must be a
    one-element tensor.
    """
    Transition = attr.make_class("Transition", ("reward", "obs", "done"))
    # Rewards run 0..3; only the final step (index 3) is marked done.
    trajectory = [
        Transition(reward=step, obs=Obs(), done=(step == 3))
        for step in range(4)
    ]
    transform = T.WithReturns(discount=0.1, norm_returns=True)

    result = transform(trajectory)
    assert isinstance(result, list)
    assert len(result) == len(trajectory)

    # Per-item checks: a float return was attached and the observation kept.
    for item in result:
        assert hasattr(item, "retrn")
        assert isinstance(item.retrn, float)
        assert hasattr(item, "obs")
        assert isinstance(item.obs, Obs)

    total_return = sum(item.retrn for item in result)
    assert abs(total_return) < 1e-5

    mean_tensor = transform.normalizer.running_means[None]
    assert isinstance(mean_tensor, torch.Tensor)
    assert mean_tensor.shape == (1, )
Ejemplo n.º 4
0
 def add_trajectories_to_engine(engine):
     """Store a fresh ``Trajectories`` object on ``engine.state.trajectories``.

     The ``Trajectories`` instance is built around a ``T.WithReturns``
     transform configured with ``discount`` and ``norm_returns``.
     NOTE(review): both names are not defined in this function; they
     presumably come from an enclosing scope -- confirm.
     """
     returns_transform = T.WithReturns(
         discount=discount, norm_returns=norm_returns)
     engine.state.trajectories = Trajectories(returns_transform)