def test_WithReturns_bootstrap():
    """Check ``WithReturns`` on a trajectory that never reaches ``done``.

    The trajectory has four steps with rewards 0, 1, 2, 3. Every step has
    ``done=False`` and ``critic_value=2``. The transform uses
    ``discount=1.0`` and has return normalisation disabled. The expected
    values (2 for the last step, 2 + 2 for the one before it) are consistent
    with the final return being bootstrapped from ``critic_value``. That is
    an inference from the assertions only; confirm against ``WithReturns``.
    """
    Transition = attr.make_class(
        "Transition", ("reward", "obs", "done", "critic_value")
    )
    steps = [Transition(reward, Obs(), False, 2) for reward in range(4)]

    with_returns = T.WithReturns(discount=1.0, norm_returns=False)
    result = with_returns(steps)

    # The transform must keep one output entry per input step.
    assert len(result) == len(steps)
    assert result[-1].retrn == 2
    assert approx_eq(result[-2].retrn, 2 + 2)
def test_WithReturns_Normalize():
    """Check that ``WithReturns(norm_returns=True)`` uses its normalizer.

    The transform's ``normalizer`` is replaced by a ``MagicMock`` that wraps
    the real object. Calls are forwarded to the original and also recorded.

    Two trajectories are fed through the transform:

    * One whose last step has ``done=True``. Afterwards, ``normalize`` must
      have been called.
    * One where no step is done. Afterwards, ``denormalize`` must have been
      called.
    """
    Transition = attr.make_class(
        "Transition", ("reward", "obs", "done", "critic_value")
    )
    transform = T.WithReturns(norm_returns=True)
    # Spy on the real normalizer without changing what it computes.
    transform.normalizer = mock.MagicMock(wraps=transform.normalizer)

    terminated = [Transition(step, Obs(), step == 3, 2) for step in range(4)]
    transform(terminated)
    assert transform.normalizer.normalize.called

    unterminated = [Transition(step, Obs(), False, 2) for step in range(4)]
    transform(unterminated)
    assert transform.normalizer.denormalize.called
def test_WithReturns():
    """Check ``WithReturns`` output on a terminating, normalised trajectory.

    The trajectory has four steps with rewards 0..3, and only the final
    step is ``done``. With ``discount=0.1`` and ``norm_returns=True``, this
    test verifies the following:

    * the result is a list of the same length as the input;
    * every entry has a float ``retrn`` and keeps an ``Obs`` ``obs``;
    * the normalised returns sum to approximately zero;
    * the normalizer stores a running mean of shape ``(1,)`` under key
      ``None``, as a ``torch.Tensor``.
    """
    Transition = attr.make_class("Transition", ("reward", "obs", "done"))
    trajectory = [Transition(step, Obs(), step == 3) for step in range(4)]
    transform = T.WithReturns(discount=0.1, norm_returns=True)

    out = transform(trajectory)

    assert isinstance(out, list)
    assert len(out) == len(trajectory)
    assert all(hasattr(item, "retrn") for item in out)
    assert all(isinstance(item.retrn, float) for item in out)
    assert all(hasattr(item, "obs") for item in out)
    assert all(isinstance(item.obs, Obs) for item in out)

    # After normalisation the returns should be centred on zero.
    total = sum(item.retrn for item in out)
    assert abs(total) < 1e-5

    # NOTE(review): the meaning of the ``None`` key is not visible here.
    # It is presumably the default stream; confirm against the normalizer.
    mean = transform.normalizer.running_means[None]
    assert isinstance(mean, torch.Tensor)
    assert mean.shape == (1, )
def add_trajectories_to_engine(engine):
    """Store a ``Trajectories`` object on ``engine.state.trajectories``.

    The ``Trajectories`` object is built around a ``WithReturns`` transform.
    ``discount`` and ``norm_returns`` are not parameters of this function.
    They are resolved from an enclosing scope, presumably the function this
    helper is nested in; confirm against the full file.
    """
    returns_transform = T.WithReturns(
        discount=discount, norm_returns=norm_returns
    )
    engine.state.trajectories = Trajectories(returns_transform)