示例#1
0
def test_collate():
    """Round-trip a list of per-sample dicts through collate/uncollate.

    Builds 100 batches of 100 rows, each row a dict of three random
    ``(1, 3)`` tensors, and checks that ``uncollate(collate(rows))`` yields
    the same nested values as the original rows (compared via ``tolist``).
    """
    batch_size = 100
    keys = ("a", "b", "c")
    for _ in range(100):
        rows = [
            {key: torch.rand((1, 3)) for key in keys}
            for _ in range(batch_size)
        ]
        round_tripped = uncollate(collate(rows))
        assert tolist(round_tripped) == tolist(rows)
示例#2
0
    def sample(
        self,
        start_obs: Union[None, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """Roll out ``self.policy`` in ``self.env`` for at most ``self.path_length`` steps.

        The rollout stops early as soon as the environment reports ``done``.

        Args:
            start_obs: Observation to start from. When ``None``, the
                environment is reset and its returned observation is used.

        Returns:
            A ``(obs, data)`` pair. ``obs`` is the last observation reached.
            ``data`` is the list of per-step transition dicts (keys ``"obs"``,
            ``"act"``, ``"next_obs"``, ``"done"``, ``"rew"``) passed through
            ``collate`` (presumably batching each key — confirm).

        Side effects:
            Appends ``data["rew"].sum()`` to ``self._total_rewards`` and the
            torchified ``info`` of the final step to ``self._infos``.
        """
        transitions = []
        # Compare against None explicitly: truth-testing a multi-element
        # tensor raises, and a falsy (e.g. all-zero) start observation would
        # otherwise be mistaken for "none given" and trigger an env reset.
        obs = start_obs if start_obs is not None else self.env.reset()

        for _ in range(self.path_length):
            action = self.policy(obs)
            next_obs, reward, done, info = self.env.step(action)
            transitions.append(
                {
                    "obs": obs,
                    "act": action,
                    "next_obs": next_obs,
                    "done": done,
                    "rew": reward,
                }
            )
            obs = next_obs
            if done:
                break

        data = collate(transitions)

        self._total_rewards.append(data["rew"].sum())
        # NOTE(review): assumes self.path_length >= 1; with zero steps `info`
        # is never bound and this line raises NameError.
        self._infos.append(torchify(info))
        return obs, data
示例#3
0
 def log_local_epoch(self):
     """Package the rewards and infos accumulated so far, then clear them.

     Returns:
         dict with ``"total_reward"`` (``torch.stack`` of the accumulated
         ``self._total_rewards``) and ``"info"`` (``collate`` of the
         accumulated ``self._infos``).

     The accumulators are only reset after both values are built, so if
     either call raises the accumulated data is left in place.
     """
     # NOTE(review): torch.stack rejects an empty list, so this presumably
     # expects at least one rollout since the last call — confirm.
     rewards, infos = self._total_rewards, self._infos
     epoch = {"total_reward": torch.stack(rewards), "info": collate(infos)}
     self._total_rewards, self._infos = [], []
     return epoch
    def test_multi_space():
        """Check ``model(space)`` against a collated batch drawn from random spaces.

        For each of ``NUM_SAMPLES`` random spaces, draws ``size`` samples,
        torchifies and collates them, runs the model's ``forward`` on the
        batched ``"obs"`` entry, and compares the output to the input with
        ``custom_equals``. Intermediate values are printed for debugging.
        """
        for _ in range(NUM_SAMPLES):
            space = create_random_space()
            net = model(space)

            raw = [{"obs": space.sample()} for _ in range(size)]
            batch = collate([torchify(item) for item in raw])["obs"]

            print("=" * 50)
            print("SPACE:", space)
            print("SAMPLE:", batch)
            output = net.forward(batch)
            print("FORWARD:", output)
            custom_equals(output, batch)
def test_actor():
    """The actor's chosen actions must score at least as high as random ones.

    For each available device, builds a CartPole ``DqnCritic``/``DqnActor``
    pair and, over 300 batches of 20 random observations, checks that
    ``critic(obs, actor(obs))`` is element-wise >= ``critic(obs, rand_act)``
    for randomly sampled actions.
    """
    # Only exercise CUDA when a device is present; otherwise the test fails
    # on CPU-only machines for reasons unrelated to the actor.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    batch = 20
    for device in devices:
        env = make_env("CartPole-v1", device)
        critic = DqnCritic(env.observation_space, env.action_space, [256, 256],
                           device)
        actor = DqnActor(critic, env.observation_space, env.action_space,
                         device)
        for _ in range(300):
            obs = collate([
                torchify(env.observation_space.sample(), device)
                for _ in range(batch)
            ])
            act = actor(obs)
            rand_act = torch.cat([
                torchify(env.action_space.sample(), device)
                for _ in range(batch)
            ])
            assert torch.all(critic(obs, act) >= critic(obs, rand_act))
示例#6
0
def test_collate_2():
    """Round-trip a dict of batched tensors through uncollate/collate.

    For 100 random dicts of three ``(100, 3)`` tensors, checks that
    ``collate(uncollate(batch))`` reproduces every entry element-wise.
    """
    for _ in range(100):
        batch = {key: torch.rand((100, 3)) for key in ("a", "b", "c")}
        restored = collate(uncollate(batch))
        for key, expected in batch.items():
            assert (expected == restored[key]).all()