Exemplo n.º 1
0
    def test_config_init(self, train_config):
        """Smoke-test building RAINBOW from a generated config.

        Generates a default config, points both model slots at ``"QNet"``
        with dimensions taken from ``train_config``, constructs the framework
        through ``RAINBOW.init_from_config``, stores a three-step episode of
        all-zero transitions and runs a single ``update()``.

        There are no assertions; the test passes when none of these calls
        raises.
        """
        cfg = RAINBOW.generate_config({})
        cfg["frame_config"]["models"] = ["QNet"] * 2

        # One kwargs dict referenced twice: both models receive the same
        # constructor arguments (and the very same dict object).
        net_kwargs = {
            "state_dim": train_config.observe_dim,
            "action_num": train_config.action_num,
        }
        cfg["frame_config"]["model_kwargs"] = [net_kwargs, net_kwargs]
        agent = RAINBOW.init_from_config(cfg)

        # A single zero tensor is used as both "state" and "next_state" of
        # every transition.
        zero_obs = t.zeros([1, train_config.observe_dim], dtype=t.float32)
        zero_action = t.zeros([1, 1], dtype=t.int)

        episode = []
        for _ in range(3):
            episode.append({
                "state": {"state": zero_obs},
                "action": {"action": zero_action},
                "next_state": {"state": zero_obs},
                "reward": 0,
                "terminal": False,
            })

        agent.store_episode(episode)
        agent.update()
Exemplo n.º 2
0
 def rainbow_train(self, train_config):
     """Return a RAINBOW instance whose networks and replay buffer are on CPU.

     Two separate ``QNet`` objects are created from
     ``train_config.observe_dim`` / ``train_config.action_num``, each wrapped
     by ``smw`` with ``"cpu"`` for both device arguments, and handed to
     ``RAINBOW`` together with ``t.optim.Adam`` and the value range, future
     reward steps and replay size read from ``train_config``.
     """
     def build_cpu_net():
         # cpu is faster for testing full training.
         net = QNet(train_config.observe_dim, train_config.action_num)
         return smw(net, "cpu", "cpu")

     primary_q = build_cpu_net()
     secondary_q = build_cpu_net()

     return RAINBOW(
         primary_q,
         secondary_q,
         t.optim.Adam,
         train_config.value_min,
         train_config.value_max,
         reward_future_steps=train_config.reward_future_steps,
         replay_device="cpu",
         replay_size=train_config.replay_size,
     )
Exemplo n.º 3
0
 def rainbow(self, train_config, device, dtype):
     """Return a RAINBOW instance with networks cast to ``dtype`` on ``device``.

     Each of the two ``QNet`` objects is converted with ``.type(dtype)``,
     moved with ``.to(device)`` and wrapped by ``smw`` using ``device`` for
     both device arguments. The replay buffer is kept on ``"cpu"``
     regardless of ``device``.
     """
     def build_net():
         net = QNet(train_config.observe_dim, train_config.action_num)
         net = net.type(dtype).to(device)
         return smw(net, device, device)

     primary_q = build_net()
     secondary_q = build_net()

     options = {
         "reward_future_steps": train_config.reward_future_steps,
         "replay_device": "cpu",
         "replay_size": train_config.replay_size,
     }
     return RAINBOW(
         primary_q,
         secondary_q,
         t.optim.Adam,
         train_config.value_min,
         train_config.value_max,
         **options,
     )
Exemplo n.º 4
0
 def rainbow_vis(self, train_config, device, dtype, tmpdir):
     """Return a RAINBOW instance with visualization turned on.

     A fresh directory is obtained from ``tmpdir.make_numbered_dir()`` and
     its string form is passed as ``visualize_dir`` alongside
     ``visualize=True``. The two ``QNet`` objects are cast to ``dtype``,
     moved to ``device`` and wrapped by ``smw``; the replay buffer is kept
     on ``"cpu"``.
     """
     # Create the output directory first, before any network is built.
     vis_dir = tmpdir.make_numbered_dir()

     wrapped_nets = []
     for _ in range(2):
         net = QNet(train_config.observe_dim, train_config.action_num)
         wrapped_nets.append(smw(net.type(dtype).to(device), device, device))
     primary_q, secondary_q = wrapped_nets

     return RAINBOW(
         primary_q,
         secondary_q,
         t.optim.Adam,
         train_config.value_min,
         train_config.value_max,
         reward_future_steps=train_config.reward_future_steps,
         replay_device="cpu",
         replay_size=train_config.replay_size,
         visualize=True,
         visualize_dir=str(vis_dir),
     )