def test_no_nan():
    """Test for no nans in all parameters.

    The main takeaway from this test is that you must set the learning
    rates low or else the parameters will tend to nan.

    """

    env = gym.make("InvertedPendulum-v2")
    a = TanhGaussianActor(env.observation_space, env.action_space, [256, 256])
    c = DoubleQCritic(env.observation_space, env.action_space, [256, 256])
    s = SAC(actor=a,
            critic=c,
            _device="cuda",
            _act_dim=len(env.action_space.low))

    batch = create_batch(env)
    for t in range(200):
        print("iteration", t)
        s.update(batch, t)
        for key, v in a.state_dict().items():
            print("actor", key)
            assert not torch.isnan(v).any()
        for key, v in c.state_dict().items():
            print("critic", key)
            assert not torch.isnan(v).any()
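
# Several of these tests rely on a `create_batch` helper that is not shown here.
# A minimal sketch of what it might look like, assuming the SAC update consumes a
# dict of random transitions keyed by "obs", "action", "reward", "next_obs", and
# "not_done"; the real helper and its key names may differ.
import numpy as np
import torch

def create_batch(env, device="cuda", batch_size=32):
    # Build a batch of random transitions by sampling the env's spaces.
    def sample(space):
        return torch.as_tensor(
            np.stack([space.sample() for _ in range(batch_size)]),
            dtype=torch.float32, device=device)

    return {
        "obs": sample(env.observation_space),
        "action": sample(env.action_space),
        "reward": torch.rand((batch_size, 1), device=device),
        "next_obs": sample(env.observation_space),
        "not_done": torch.ones((batch_size, 1), device=device),
    }
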
def test_random_space_tanh():
    for _ in range(100):
        obs_spec = create_random_space()
        act_spec = create_random_space()
        print(obs_spec)
        print(act_spec)
        actor = TanhGaussianActor(obs_spec, act_spec, [60, 50])
        actor.action(torchify(obs_spec.sample()))
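
# The tests and the policy below call `torchify` / `untorchify` without defining
# them. A minimal sketch, assuming they recursively convert between numpy data
# (possibly nested in dicts, e.g. samples from a Dict space) and float tensors;
# the project's real helpers may handle more cases.
import numpy as np
import torch

def torchify(x, device="cpu"):
    # Recursively move numpy data (or dicts of it) onto the device as float tensors.
    if isinstance(x, dict):
        return {k: torchify(v, device) for k, v in x.items()}
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)

def untorchify(x):
    # Inverse direction: detach tensors and bring them back to numpy on the CPU.
    if isinstance(x, dict):
        return {k: untorchify(v) for k, v in x.items()}
    return x.detach().cpu().numpy()
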
def test_integration_tanh():
    obs_spec = Box(low=np.zeros(10, dtype=np.float32),
                   high=np.ones(10, dtype=np.float32))
    act_spec = Box(low=np.zeros(3, dtype=np.float32),
                   high=np.ones(3, dtype=np.float32))
    actor = TanhGaussianActor(obs_spec, act_spec, [60, 50])
    obs = torch.rand((100, 10))
    actions = actor.action(obs)
    assert actions.shape == (100, 3)
class DiaynPolicy(Policy):
    @staticmethod
    def check(path: str):
        path = Path(path).absolute()
        if not (path.parent / "hyperparams.json").exists():
            return False
        with open(path.parent / "hyperparams.json", "r") as f:
            hyperparams = json.load(f)
        return "diayn" in hyperparams

    def __init__(self, path: str, device: str):
        path = Path(path).absolute()
        with open(path.parent / "hyperparams.json", "r") as f:
            hyperparams = json.load(f)

        checkpoint = torch.load(path, "cpu")
        self.ns = hyperparams["diayn"]["num_skills"]
        env = make_env("Ant-v4", device)

        diayn = DiscreteDiayn(
            env.observation_space,
            hyperparams["diayn"]["model"]["hidden_dim"],
            self.ns,
            hyperparams["diayn"]["reward_weight"],
            _truncate=hyperparams["diayn"]["truncate"],
        )
        env = DiaynWrapper(env, diayn)
        self.forward_policy = TanhGaussianActor(
            obs_spec=env.observation_space,
            act_spec=env.action_space,
            hidden_dim=hyperparams["sac"]["actor"]["policy"]["hidden_dim"],
        )
        self.forward_policy.load_state_dict(checkpoint["sac"]["actor"]["state_dict"])
        self.forward_policy.to(device)
        self.device = device

    @property
    def num_skills(self):
        return self.ns

    def forward(self, obs, skill):
        obs = {
            "observations": torchify(obs, self.device),
            "diayn": torchify(skill, self.device),
        }
        act = self.forward_policy.action(obs, deterministic=True)
        return untorchify(act)

    def backward(self, obs):
        raise NotImplementedError
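
# Hypothetical usage of DiaynPolicy. The checkpoint path is a placeholder, the
# skill encoding (a one-hot vector over num_skills) is an assumption based on the
# DIAYN convention, and the zero observation is only a stand-in for a real
# Ant-v4 observation.
ckpt = "runs/ant_diayn/checkpoint.pt"  # placeholder path
if DiaynPolicy.check(ckpt):
    policy = DiaynPolicy(ckpt, device="cuda")
    skill = np.eye(policy.num_skills, dtype=np.float32)[0]  # assumed one-hot skill
    obs = np.zeros(27, dtype=np.float32)  # stand-in for an Ant-v4 observation
    action = policy.forward(obs, skill)
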
def test_critic_target_update():
    env = gym.make("InvertedPendulum-v2")
    a = TanhGaussianActor(env.observation_space, env.action_space, [256, 256])
    c = DoubleQCritic(env.observation_space, env.action_space, [256, 256])
    s = SAC(
        actor=a,
        critic=c,
        _device="cuda",
        _act_dim=len(env.action_space.low),
        _critic_target_update_frequency=200,
    )

    batch = create_batch(env)
    # Clone so in-place target updates cannot alias the "before" snapshot.
    cp_before = {k: v.clone() for k, v in s.critic_target.state_dict().items()}
    for t in range(100):
        s.update(batch, t + 1)
    cp_after = s.critic_target.state_dict()

    for k, v in cp_before.items():
        v2 = cp_after[k]
        assert torch.all(v == v2)

    cp_before = {k: v.clone() for k, v in s.critic_target.state_dict().items()}
    s.update(batch, 1000)
    cp_after = s.critic_target.state_dict()

    for k, v in cp_before.items():
        v2 = cp_after[k]
        assert not torch.all(v == v2)
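
# The test above assumes the target critic only changes on steps that are a
# multiple of _critic_target_update_frequency. A plausible sketch of that gate,
# assuming Polyak (soft) target updates; the actual SAC.update logic may differ.
def soft_update(net, target_net, tau=0.005):
    # Polyak-average the online network's parameters into the target network.
    with torch.no_grad():
        for p, tp in zip(net.parameters(), target_net.parameters()):
            tp.mul_(1.0 - tau).add_(tau * p)

# Inside a hypothetical SAC.update(batch, step):
#     if step % self._critic_target_update_frequency == 0:
#         soft_update(self.critic, self.critic_target)
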
def test_log_tanh():
    obs_spec = Box(low=np.zeros(10, dtype=np.float32),
                   high=np.ones(10, dtype=np.float32))
    act_spec = Box(low=np.zeros(3, dtype=np.float32),
                   high=np.ones(3, dtype=np.float32))
    actor = TanhGaussianActor(obs_spec, act_spec, [60, 50])
    obs = torch.rand((100, 10))
    actor.action(obs)
    print(flatten(actor.log_hyperparams()).keys())
    print(flatten(actor.log_epoch()).keys())
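
# `flatten` is used by the logging tests but not shown. A minimal sketch,
# assuming it collapses nested metric dicts into slash-separated keys (matching
# names like "actor/loss" and "critic/q1" used below); the real helper may differ.
def flatten(d, prefix=""):
    # Collapse nested dicts into one level with slash-joined keys.
    out = {}
    for k, v in d.items():
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            out.update(flatten(v, prefix=key + "/"))
        else:
            out[key] = v
    return out
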
def test_actor_loss_decrease():
    env = gym.make("InvertedPendulum-v2")
    a = TanhGaussianActor(env.observation_space, env.action_space, [256, 256])
    c = DoubleQCritic(env.observation_space, env.action_space, [256, 256])
    s = SAC(actor=a,
            critic=c,
            _device="cuda",
            _act_dim=len(env.action_space.low))

    batch = create_batch(env)
    batch = {"obs": batch["obs"]}
    s.update_actor_and_alpha(**batch)
    loss_before = s.log_local_epoch()["actor/loss"]
    for _ in range(200):
        s.update_actor_and_alpha(**batch)
    loss_after = s.log_local_epoch()["actor/loss"]
    assert loss_after < loss_before + 0.2
def test_integration():
    env = gym.make("InvertedPendulum-v2")

    for device in "cpu", "cuda":
        a = TanhGaussianActor(env.observation_space, env.action_space,
                              [256, 256])
        c = DoubleQCritic(env.observation_space, env.action_space, [256, 256])
        s = SAC(actor=a,
                critic=c,
                _device=device,
                _act_dim=len(env.action_space.low))

        print(flatten(s.log_hyperparams()).keys())

        batch = create_batch(env, device)
        for t in range(10):
            s.update(batch, t)

        print(flatten(s.log_epoch()).keys())
def test_critic_value_increase():
    env = gym.make("InvertedPendulum-v2")
    a = TanhGaussianActor(env.observation_space, env.action_space, [256, 256])
    c = DoubleQCritic(env.observation_space, env.action_space, [256, 256])
    s = SAC(actor=a,
            critic=c,
            _device="cuda",
            _act_dim=len(env.action_space.low))

    batch = create_batch(env)

    s.update_critic(**batch)
    q1_before = s.log_local_epoch()["critic/q1"].mean()
    q2_before = s.log_local_epoch()["critic/q2"].mean()
    for _ in range(200):
        s.update_critic(**batch)
    q1_after = s.log_local_epoch()["critic/q1"].mean()
    q2_after = s.log_local_epoch()["critic/q2"].mean()
    assert q1_after > q1_before - 0.2
    assert q2_after > q2_before - 0.2