def action_with_log_prob(self, obs, deterministic=False):
    """Choose an epsilon-greedy action for every observation in the batch.

    Args:
        obs: Batch of observations; flattened with ``self.obs_flat`` first.
        deterministic: When true, exploration is skipped and the greedy
            (arg-max) action is always returned.

    Returns:
        An ``(action, None)`` pair. The second element is always ``None``;
        this method never computes a log-probability.
    """
    self.log("epsilon", torchify(self.epsilon))
    obs = self.obs_flat(obs)

    # Exploration: a single random draw per call decides whether the whole
    # batch receives actions sampled from the action spec. The `and` chain
    # short-circuits, so no random number is consumed when exploration is
    # disabled.
    explore = (not deterministic and self.epsilon > 0
               and torch.rand(1).item() < self.epsilon)
    if explore:
        act = torch.cat([
            torchify(self.act_spec.sample()).long() for _ in range(len(obs))
        ]).to(obs.device)
        return act, None

    # Exploitation: pair each observation with every one-hot action, score
    # all pairs with the critic's Q-function and take the arg-max over the
    # action axis (dim 1).
    n = self.act_spec.n
    # Build the identity directly on the observation's device instead of
    # allocating it on the CPU and copying it over on every call.
    one_hot_actions = torch.eye(n, device=obs.device)
    obs_act = torch.cat(
        [
            obs.unsqueeze(1).repeat(1, n, 1),
            one_hot_actions.unsqueeze(0).repeat(len(obs), 1, 1),
        ],
        dim=2,
    )
    act = self.critic.qf(obs_act).argmax(dim=1)
    return act, None
def forward(self, obs, skill):
    """Return the forward policy's deterministic action for ``obs`` and ``skill``.

    Both inputs are converted with ``torchify`` onto ``self.device`` and
    handed to the policy under the ``"observations"`` and ``"diayn"`` keys;
    the resulting action is converted back with ``untorchify``.
    """
    policy_input = {
        "observations": torchify(obs, self.device),
        "diayn": torchify(skill, self.device),
    }
    return untorchify(
        self.forward_policy.action(policy_input, deterministic=True))
def test_integration():
    """Round-trip one transition through ``ReplayBuffer`` on random spaces.

    For 100 random observation/action space pairs, a single step is added to
    a fresh buffer and sampled back; after flattening, the sampled step must
    have the same keys and element-wise equal values as the one inserted.
    """
    # Only exercise CUDA when it is actually present, so the test does not
    # error out on CPU-only machines.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    for _ in range(100):
        for device in devices:
            obs_space = create_random_space()
            act_space = create_random_space()
            buf = ReplayBuffer(obs_space, act_space, int(1e5), 1, device)
            print(buf.log_hyperparams())
            print("OBSSPEC", obs_space)
            print("ACTSPEC", act_space)
            step = {
                "obs": torchify(obs_space.sample(), device),
                "act": torchify(act_space.sample(), device),
                "rew": torchify(1.0, device),
                "next_obs": torchify(obs_space.sample(), device),
                "done": torchify(0, device),
            }
            buf.add(step)
            # The buffer holds exactly one transition, so the sample must be it.
            step2 = buf.sample()
            step = flatten(step)
            step2 = flatten(step2)
            assert step.keys() == step2.keys()
            for k in step:
                # Compare on the CPU so the check is device-independent.
                assert torch.all(step[k].cpu() == step2[k].cpu())
            print(buf.log_epoch())
def step(self, action: Tensor) -> Tuple[Tensor, Tensor, Tensor, dict]:
    """Advance the wrapped environment by one step.

    The action is converted with ``untorchify`` before being passed to
    ``self.env.step``; the resulting observation, reward and done flag are
    converted with ``torchify`` onto ``self.device``. ``info`` is returned
    untouched.

    Returns:
        ``(next_obs, reward, done, info)``.
    """
    next_obs, reward, done, info = self.env.step(untorchify(action))
    converted = tuple(
        torchify(value, self.device) for value in (next_obs, reward, done))
    return (*converted, info)
def test_critic():
    """Smoke-test ``DqnCritic`` on CartPole with randomly sampled inputs.

    The critic is called 300 times per device with a sampled observation and
    action; the test passes as long as no call raises.
    """
    # Local import: this block did not reference torch before, and it is only
    # needed to detect whether CUDA can be exercised.
    import torch

    # Skip CUDA on machines without it instead of failing the whole test.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    for device in devices:
        env = make_env("CartPole-v1", device)
        critic = DqnCritic(env.observation_space, env.action_space,
                           [256, 256], device)
        for _ in range(300):
            critic(
                torchify(env.observation_space.sample(), device),
                torchify(env.action_space.sample(), device),
            )
def test_random_space():
    """Smoke-test the double Q and double V critics on random spaces.

    For 100 random observation/action space pairs, both critics are built
    and run forward once on a sampled input; the test passes if nothing
    raises.
    """
    for _ in range(100):
        obs_spec = create_random_space()
        act_spec = create_random_space()
        print(obs_spec)
        print(act_spec)

        q_critic = DoubleQCritic(obs_spec, act_spec, [60, 50])
        v_critic = DoubleVCritic(obs_spec, [60, 50])

        sampled_obs = torchify(obs_spec.sample())
        sampled_act = torchify(act_spec.sample())

        q_critic.forward(sampled_obs, sampled_act)
        v_critic.forward(sampled_obs)
def sample(
    self,
    start_obs: Union[None, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
    """Roll out ``self.policy`` in ``self.env`` for up to ``self.path_length`` steps.

    Args:
        start_obs: Observation to start from. When ``None`` the environment
            is reset and its initial observation is used.

    Returns:
        ``(obs, data)`` where ``obs`` is the last observation reached and
        ``data`` is the collated batch of transitions with keys ``"obs"``,
        ``"act"``, ``"next_obs"``, ``"done"`` and ``"rew"``.

    Side effects:
        Appends the summed reward of the rollout to ``self._total_rewards``
        and the ``info`` of the final step to ``self._infos``.
    """
    data = []
    # Compare against None explicitly: the truth value of a multi-element
    # tensor raises, and an all-zero single-element observation would be
    # mistaken for "no start observation".
    obs = start_obs if start_obs is not None else self.env.reset()
    # NOTE(review): assumes self.path_length >= 1; otherwise `info` below is
    # never bound.
    for _ in range(self.path_length):
        action = self.policy(obs)
        next_obs, reward, done, info = self.env.step(action)
        data.append({
            "obs": obs,
            "act": action,
            "next_obs": next_obs,
            "done": done,
            "rew": reward,
        })
        obs = next_obs
        if done:
            # Stop the rollout early when the environment reports done.
            break
    data = collate(data)
    self._total_rewards.append(data["rew"].sum())
    self._infos.append(torchify(info))
    return obs, data
def test_actor():
    """Check that ``DqnActor`` picks actions the critic rates at least as high as random ones.

    For each device, 300 batches of 20 sampled CartPole observations are
    collated; the critic's value for the actor's actions must be
    element-wise >= its value for randomly sampled actions.
    """
    # Only include CUDA when it is available so the test still runs on
    # CPU-only machines.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    for device in devices:
        env = make_env("CartPole-v1", device)
        critic = DqnCritic(env.observation_space, env.action_space,
                           [256, 256], device)
        actor = DqnActor(critic, env.observation_space, env.action_space,
                         device)
        for _ in range(300):
            obs = collate([
                torchify(env.observation_space.sample(), device)
                for _ in range(20)
            ])
            # NOTE(review): presumably actor(obs) acts greedily here; if it
            # can explore, the assertion below may fail spuriously - confirm.
            act = actor(obs)
            rand_act = torch.cat([
                torchify(env.action_space.sample(), device)
                for _ in range(20)
            ])
            assert torch.all(critic(obs, act) >= critic(obs, rand_act))
def test_random_space_tanh():
    """Smoke-test ``TanhGaussianActor.action`` on 100 random space pairs.

    The test passes as long as building the actor and requesting an action
    for a sampled observation does not raise.
    """
    for _ in range(100):
        obs_spec = create_random_space()
        act_spec = create_random_space()
        print(obs_spec)
        print(act_spec)

        policy = TanhGaussianActor(obs_spec, act_spec, [60, 50])
        sampled_obs = torchify(obs_spec.sample())
        policy.action(sampled_obs)
def sample(self, batch_size=None):
    """Draw a batch from the underlying buffer and restore its structure.

    Args:
        batch_size: Number of transitions to draw. Any falsy value
            (``None``, ``0``) falls back to ``self.batch_size``.

    Returns:
        A dict with keys ``"obs"``, ``"act"``, ``"next_obs"``, ``"rew"`` and
        ``"done"``; observations and actions are passed through
        ``self.obs_unflat`` / ``self.act_unflat``.
    """
    size = batch_size or self.batch_size
    flat = torchify(self.buffer.sample(size), self.device)
    return {
        "obs": self.obs_unflat(flat["obs"]),
        "act": self.act_unflat(flat["act"]),
        "next_obs": self.obs_unflat(flat["next_obs"]),
        "rew": flat["rew"],
        "done": flat["done"],
    }
def sample(self, batch_size=None, with_index=False) -> dict:
    """Draw a batch from the underlying buffer and restore its structure.

    Args:
        batch_size: Number of transitions to draw. Any falsy value
            (``None``, ``0``) falls back to ``self.batch_size``.
        with_index: When truthy, the batch's ``"index"`` entry is included
            in the result.

    Returns:
        A dict with keys ``"obs"``, ``"act"``, ``"next_obs"``, ``"rew"`` and
        ``"done"`` (plus ``"index"`` if requested); observations and actions
        are passed through ``self.obs_unflat`` / ``self.act_unflat``.
    """
    size = batch_size or self.batch_size
    flat: dict = torchify(self.buffer.sample(size), self.device)
    out = {
        "obs": self.obs_unflat(flat["obs"]),
        "act": self.act_unflat(flat["act"]),
        "next_obs": self.obs_unflat(flat["next_obs"]),
        "rew": flat["rew"],
        "done": flat["done"],
    }
    if not with_index:
        return out
    out["index"] = flat["index"]
    return out
def test_multi_space():
    """Run ``model`` forward on collated samples from random spaces.

    For ``NUM_SAMPLES`` random spaces, ``size`` samples are drawn, converted
    with ``torchify`` and collated; the model's forward output is then
    compared against the input batch with ``custom_equals``.
    """
    for _ in range(NUM_SAMPLES):
        space = create_random_space()
        m = model(space)

        raw_samples = [{"obs": space.sample()} for _ in range(size)]
        batch = collate([torchify(item) for item in raw_samples])["obs"]

        print("=" * 50)
        print("SPACE:", space)
        print("SAMPLE:", batch)

        output = m.forward(batch)
        print("FORWARD:", output)
        # NOTE(review): the result of custom_equals is not asserted here;
        # presumably it raises on mismatch - confirm.
        custom_equals(output, batch)
def backward(self, obs):
    """Return the backward policy's deterministic action for ``obs``.

    The observation is converted with ``torchify`` onto ``self.device`` and
    passed under the ``"observations"`` key; the resulting action is
    converted back with ``untorchify``.
    """
    policy_input = {"observations": torchify(obs, self.device)}
    return untorchify(
        self.backward_policy.action(policy_input, deterministic=True))
def reset(self) -> dict:
    """Reset the wrapped environment and return its first observation.

    The observation is converted with ``torchify`` onto ``self.device``.
    """
    return torchify(self.env.reset(), self.device)