Example #1
    def action_with_log_prob(self, obs, deterministic=False):
        self.log("epsilon", torchify(self.epsilon))

        # epsilon-greedy branch: with probability epsilon, return a random action per observation
        obs = self.obs_flat(obs)
        if (not deterministic and self.epsilon > 0
                and torch.rand(1).item() < self.epsilon):
            act = torch.cat([
                torchify(self.act_spec.sample()).long()
                for _ in range(len(obs))
            ]).to(obs.device)
            return act, None

        # greedy branch: evaluate Q(obs, a) for every one-hot action and pick the argmax
        n = self.act_spec.n
        obs_act = torch.cat(
            [
                obs.unsqueeze(1).repeat(1, n, 1),
                torch.eye(n).unsqueeze(0).repeat(len(obs), 1, 1).to(
                    obs.device),
            ],
            dim=2,
        )
        act = self.critic.qf(obs_act).argmax(dim=1)
        return act, None
Example #2
    def forward(self, obs, skill):
        # package the raw observation and the DIAYN skill for the skill-conditioned policy
        obs = {
            "observations": torchify(obs, self.device),
            "diayn": torchify(skill, self.device),
        }
        act = self.forward_policy.action(obs, deterministic=True)
        return untorchify(act)
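Most of the examples on this page go through the torchify/untorchify helpers, which are not shown. Below is a minimal sketch of what they are assumed to do, written only to make the examples self-explanatory; the repository's real helpers may handle dtypes and nested containers differently.
import numpy as np
import torch


def torchify(x, device="cpu"):
    # assumed behaviour: recursively convert numpy arrays, scalars, or dicts
    # of them into torch tensors on the requested device
    if isinstance(x, dict):
        return {k: torchify(v, device) for k, v in x.items()}
    return torch.as_tensor(np.asarray(x)).to(device)


def untorchify(x):
    # assumed inverse: move tensors (or dicts of tensors) back to numpy on the CPU
    if isinstance(x, dict):
        return {k: untorchify(v) for k, v in x.items()}
    return x.detach().cpu().numpy()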
Example #3
def test_integration():
    for _ in range(100):
        for device in "cpu", "cuda":
            obs_space = create_random_space()
            act_space = create_random_space()
            buf = ReplayBuffer(obs_space, act_space, int(1e5), 1, device)
            print(buf.log_hyperparams())
            print("OBSSPEC", obs_space)
            print("ACTSPEC", act_space)

            step = {
                "obs": torchify(obs_space.sample(), device),
                "act": torchify(act_space.sample(), device),
                "rew": torchify(1.0, device),
                "next_obs": torchify(obs_space.sample(), device),
                "done": torchify(0, device),
            }
            buf.add(step)

            step2 = buf.sample()
            step = flatten(step)
            step2 = flatten(step2)
            assert step.keys() == step2.keys()
            for k in step:
                assert torch.all(step[k].cpu() == step2[k].cpu())

            print(buf.log_epoch())
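test_integration compares the inserted and sampled transition key by key via flatten. A plausible sketch is shown below, assuming flatten collapses nested observation dicts into a single level with dotted keys; only the name comes from the test, the body is a hypothetical illustration.
def flatten(d, prefix=""):
    # assumed behaviour: collapse nested dicts of tensors into one level,
    # joining keys with dots, so steps can be compared key by key
    out = {}
    for k, v in d.items():
        key = prefix + k
        if isinstance(v, dict):
            out.update(flatten(v, key + "."))
        else:
            out[key] = v
    return out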
Example #4
    def step(self, action: Tensor) -> Tuple[Tensor, Tensor, Tensor, dict]:
        action = untorchify(action)
        next_obs, reward, done, info = self.env.step(action)
        return (
            torchify(next_obs, self.device),
            torchify(reward, self.device),
            torchify(done, self.device),
            info,
        )
Example #5
def test_critic():
    for device in "cpu", "cuda":
        env = make_env("CartPole-v1", device)
        critic = DqnCritic(env.observation_space, env.action_space, [256, 256],
                           device)
        for _ in range(300):
            critic(
                torchify(env.observation_space.sample(), device),
                torchify(env.action_space.sample(), device),
            )
Example #6
def test_random_space():
    for _ in range(100):
        obs_spec = create_random_space()
        act_spec = create_random_space()
        print(obs_spec)
        print(act_spec)
        c1 = DoubleQCritic(obs_spec, act_spec, [60, 50])
        c2 = DoubleVCritic(obs_spec, [60, 50])
        obs = torchify(obs_spec.sample())
        act = torchify(act_spec.sample())
        c1.forward(obs, act)
        c2.forward(obs)
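Several of these tests draw arbitrary specs from create_random_space. A rough sketch of what such a generator could look like, assuming classic gym-style Box/Discrete/Dict spaces (hypothetical; the real helper may cover more space types and shapes):
import numpy as np
from gym import spaces


def create_random_space(max_depth=2):
    # assumed behaviour: return a randomly chosen space, optionally nesting
    # sub-spaces inside a Dict so tests cover structured observations
    kind = np.random.randint(3) if max_depth > 0 else np.random.randint(2)
    if kind == 0:
        return spaces.Box(low=-1.0, high=1.0, shape=(np.random.randint(1, 5),))
    if kind == 1:
        return spaces.Discrete(np.random.randint(2, 10))
    return spaces.Dict({
        f"key{i}": create_random_space(max_depth - 1)
        for i in range(np.random.randint(1, 3))
    })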
Example #7
    def sample(
        self,
        start_obs: Union[None, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        data = []
        # truth-testing a tensor is ambiguous, so compare against None explicitly
        obs = start_obs if start_obs is not None else self.env.reset()

        for _ in range(self.path_length):
            action = self.policy(obs)
            next_obs, reward, done, info = self.env.step(action)
            single_data = {
                "obs": obs,
                "act": action,
                "next_obs": next_obs,
                "done": done,
                "rew": reward,
            }
            data.append(single_data)
            obs = next_obs
            if done:
                break

        data = collate(data)

        self._total_rewards.append(data["rew"].sum())
        self._infos.append(torchify(info))
        return obs, data
Example #8
def test_actor():
    for device in "cpu", "cuda":
        env = make_env("CartPole-v1", device)
        critic = DqnCritic(env.observation_space, env.action_space, [256, 256],
                           device)
        actor = DqnActor(critic, env.observation_space, env.action_space,
                         device)
        for _ in range(300):
            obs = collate([
                torchify(env.observation_space.sample(), device)
                for _ in range(20)
            ])
            act = actor(obs)
            rand_act = torch.cat([
                torchify(env.action_space.sample(), device) for _ in range(20)
            ])
            assert torch.all(critic(obs, act) >= critic(obs, rand_act))
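The path sampler in Example #7, test_actor above, and test_multi_space in Example #12 all batch individual samples with collate. A minimal sketch is given below, assuming collate stacks lists of tensors and recurses into dicts; it is a hypothetical illustration, not the library's actual implementation.
import torch


def collate(samples):
    # assumed behaviour: stack a list of per-step samples into batched tensors,
    # recursing into dicts so nested observation spaces keep their structure
    first = samples[0]
    if isinstance(first, dict):
        return {k: collate([s[k] for s in samples]) for k in first}
    return torch.stack([torch.as_tensor(s) for s in samples])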
Example #9
def test_random_space_tanh():
    for _ in range(100):
        obs_spec = create_random_space()
        act_spec = create_random_space()
        print(obs_spec)
        print(act_spec)
        actor = TanhGaussianActor(obs_spec, act_spec, [60, 50])
        actor.action(torchify(obs_spec.sample()))
Example #10
    def sample(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        batch = torchify(self.buffer.sample(batch_size), self.device)
        return {
            "obs": self.obs_unflat(batch["obs"]),
            "act": self.act_unflat(batch["act"]),
            "next_obs": self.obs_unflat(batch["next_obs"]),
            "rew": batch["rew"],
            "done": batch["done"],
        }
Example #11
    def sample(self, batch_size=None, with_index=False) -> dict:
        if batch_size is None:
            batch_size = self.batch_size

        batch: dict = torchify(self.buffer.sample(batch_size), self.device)
        result = {
            "obs": self.obs_unflat(batch["obs"]),
            "act": self.act_unflat(batch["act"]),
            "next_obs": self.obs_unflat(batch["next_obs"]),
            "rew": batch["rew"],
            "done": batch["done"],
        }
        if with_index:
            result["index"] = batch["index"]
        return result
Example #12
    def test_multi_space():
        for _ in range(NUM_SAMPLES):
            space = create_random_space()
            m = model(space)

            samples = []
            for _ in range(size):
                samples.append({"obs": space.sample()})
            sample = collate([torchify(s) for s in samples])["obs"]

            print("=" * 50)
            print("SPACE:", space)
            print("SAMPLE:", sample)
            forward = m.forward(sample)
            print("FORWARD:", forward)
            custom_equals(forward, sample)
Example #13
    def backward(self, obs):
        obs = {"observations": torchify(obs, self.device)}
        act = self.backward_policy.action(obs, deterministic=True)
        return untorchify(act)
Example #14
    def reset(self) -> dict:
        obs = self.env.reset()
        return torchify(obs, self.device)