def test_collate():
    """Check that ``uncollate(collate(rows))`` reproduces a list of row dicts.

    Each trial builds ``BATCH_SIZE`` rows, each mapping the keys "a", "b", "c"
    to a fresh random ``(1, 3)`` tensor, and compares via ``tolist``.
    """
    BATCH_SIZE = 100
    keys = ("a", "b", "c")
    for _trial in range(100):
        rows = []
        for _row in range(BATCH_SIZE):
            rows.append({key: torch.rand((1, 3)) for key in keys})
        round_trip = uncollate(collate(rows))
        assert tolist(round_trip) == tolist(rows)
def sample(
    self,
    start_obs: Union[None, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
    """Roll out ``self.policy`` in ``self.env`` for up to ``self.path_length`` steps.

    Args:
        start_obs: Observation to start from. When ``None``, the environment
            is reset via ``self.env.reset()`` to obtain the first observation.

    Returns:
        A 2-tuple ``(obs, data)``. ``obs`` is the last ``next_obs`` returned by
        the environment. ``data`` is the result of ``collate`` on the list of
        per-step dicts with keys "obs", "act", "next_obs", "done", "rew".

    Side effects:
        Appends ``data["rew"].sum()`` to ``self._total_rewards`` and the
        torchified ``info`` of the *last* step to ``self._infos``.

    Raises:
        ValueError: if ``self.path_length`` is not positive, which would
            otherwise leave ``info`` unbound.
    """
    if self.path_length <= 0:
        raise ValueError(f"path_length must be positive, got {self.path_length}")

    # Compare against None explicitly: tensor truthiness raises for
    # multi-element tensors and misreads a zero-valued scalar as "absent".
    obs = start_obs if start_obs is not None else self.env.reset()

    transitions = []
    for _ in range(self.path_length):
        action = self.policy(obs)
        next_obs, reward, done, info = self.env.step(action)
        transitions.append(
            {
                "obs": obs,
                "act": action,
                "next_obs": next_obs,
                "done": done,
                "rew": reward,
            }
        )
        obs = next_obs
        if done:
            break

    batch = collate(transitions)
    self._total_rewards.append(batch["rew"].sum())
    # Only the info dict from the final step is recorded.
    self._infos.append(torchify(info))
    return obs, batch
def log_local_epoch(self):
    """Package the rewards and infos accumulated so far and clear the buffers.

    Returns a dict with "total_reward" (``torch.stack`` of
    ``self._total_rewards``) and "info" (``collate`` of ``self._infos``).
    Both lists are then replaced with new empty lists.
    """
    stacked_rewards = torch.stack(self._total_rewards)
    collated_infos = collate(self._infos)
    self._total_rewards, self._infos = [], []
    return {"total_reward": stacked_rewards, "info": collated_infos}
def test_multi_space():
    """Push a collated batch from a random space through ``model`` and compare.

    For each of ``NUM_SAMPLES`` random spaces, build a model, draw ``size``
    samples, torchify and collate them, then print and check the forward
    output with ``custom_equals``.
    """
    for _ in range(NUM_SAMPLES):
        space = create_random_space()
        net = model(space)
        raw_samples = [{"obs": space.sample()} for _ in range(size)]
        batch = collate([torchify(entry) for entry in raw_samples])["obs"]
        print("=" * 50)
        print("SPACE:", space)
        print("SAMPLE:", batch)
        output = net.forward(batch)
        print("FORWARD:", output)
        custom_equals(output, batch)
def test_actor():
    """Check that DqnActor's actions score at least as high as random ones.

    For each device, build a CartPole env, a DqnCritic and a DqnActor wrapping
    it. Over 300 trials, collate 20 random observations and assert that the
    critic's value for the actor's action is >= its value for random actions.

    CUDA is exercised only when ``torch.cuda.is_available()``. Otherwise the
    test would crash on CPU-only machines instead of testing the CPU path.
    """
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    n_samples = 20
    for device in devices:
        env = make_env("CartPole-v1", device)
        critic = DqnCritic(env.observation_space, env.action_space, [256, 256], device)
        actor = DqnActor(critic, env.observation_space, env.action_space, device)
        for _ in range(300):
            obs = collate(
                [torchify(env.observation_space.sample(), device) for _ in range(n_samples)]
            )
            act = actor(obs)
            rand_act = torch.cat(
                [torchify(env.action_space.sample(), device) for _ in range(n_samples)]
            )
            # Presumably the actor picks the critic's argmax action, so it should
            # never be valued below a random action.
            assert torch.all(critic(obs, act) >= critic(obs, rand_act))
def test_collate_2():
    """Check that ``collate(uncollate(batch))`` reproduces every column of a dict batch.

    Each trial uses keys "a", "b", "c" mapped to random ``(100, 3)`` tensors.
    """
    for _trial in range(100):
        batch = {name: torch.rand((100, 3)) for name in ("a", "b", "c")}
        rebuilt = collate(uncollate(batch))
        for name, column in batch.items():
            assert (column == rebuilt[name]).all()