def dqn_lr(self, train_config):
    # Not used for training, only used for testing the lr_scheduler API.
    c = train_config
    q_net = smw(QNet(c.observe_dim, c.action_num).to(c.device),
                c.device, c.device)
    q_net_t = smw(QNet(c.observe_dim, c.action_num).to(c.device),
                  c.device, c.device)
    # Piecewise schedule: lr 1e-3 before step 200000, then 3e-4.
    lr_func = gen_learning_rate_func([(0, 1e-3), (200000, 3e-4)],
                                     logger=logger)
    # Passing a scheduler class without its required args must fail.
    with pytest.raises(TypeError, match="missing .+ positional argument"):
        _ = DQN(q_net, q_net_t,
                t.optim.Adam,
                nn.MSELoss(reduction="sum"),
                replay_device=c.device,
                replay_size=c.replay_size,
                lr_scheduler=LambdaLR)
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction="sum"),
              replay_device=c.device,
              replay_size=c.replay_size,
              lr_scheduler=LambdaLR,
              lr_scheduler_args=((lr_func,),))
    return dqn
def test_config_init(self, train_config):
    c = train_config
    config = DQN.generate_config({})
    config["frame_config"]["models"] = ["QNet", "QNet"]
    config["frame_config"]["model_kwargs"] = [{
        "state_dim": c.observe_dim,
        "action_num": c.action_num
    }] * 2
    dqn = DQN.init_from_config(config)

    old_state = state = t.zeros([1, c.observe_dim], dtype=t.float32)
    action = t.zeros([1, 1], dtype=t.int)
    dqn.store_episode([{
        "state": {"state": old_state},
        "action": {"action": action},
        "next_state": {"state": state},
        "reward": 0,
        "terminal": False,
    } for _ in range(3)])
    dqn.update()
def dqn_train(self, train_config, request):
    c = train_config
    # CPU is faster for testing full training.
    q_net = smw(QNet(c.observe_dim, c.action_num), "cpu", "cpu")
    q_net_t = smw(QNet(c.observe_dim, c.action_num), "cpu", "cpu")
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction="sum"),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode=request.param)
    return dqn
def test_mode(self, train_config, device, dtype):
    c = train_config
    q_net = smw(QNet(c.observe_dim, c.action_num)
                .type(dtype).to(device), device, device)
    q_net_t = smw(QNet(c.observe_dim, c.action_num)
                  .type(dtype).to(device), device, device)
    # An unknown mode must be rejected at construction time.
    with pytest.raises(ValueError, match="Unknown DQN mode"):
        _ = DQN(q_net, q_net_t,
                t.optim.Adam,
                nn.MSELoss(reduction="sum"),
                replay_device="cpu",
                replay_size=c.replay_size,
                mode="invalid_mode")

    # "double" is a valid mode, so this construction must succeed.
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction="sum"),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode="double")
    old_state = state = t.zeros([1, c.observe_dim], dtype=dtype)
    action = t.zeros([1, 1], dtype=t.int)
    dqn.store_episode([{
        "state": {"state": old_state},
        "action": {"action": action},
        "next_state": {"state": state},
        "reward": 0,
        "terminal": False,
    } for _ in range(3)])

    # Corrupting the mode afterwards must make update() fail.
    dqn.mode = "invalid_mode"
    with pytest.raises(ValueError, match="Unknown DQN mode"):
        dqn.update(update_value=True,
                   update_target=True,
                   concatenate_samples=True)
def dqn(self, train_config, device, dtype, request):
    c = train_config
    q_net = smw(QNet(c.observe_dim, c.action_num)
                .type(dtype).to(device), device, device)
    q_net_t = smw(QNet(c.observe_dim, c.action_num)
                  .type(dtype).to(device), device, device)
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction="sum"),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode=request.param)
    return dqn
def dqn_vis(self, train_config, device, dtype, tmpdir, request):
    c = train_config
    tmp_dir = tmpdir.make_numbered_dir()
    q_net = smw(QNet(c.observe_dim, c.action_num)
                .type(dtype).to(device), device, device)
    q_net_t = smw(QNet(c.observe_dim, c.action_num)
                  .type(dtype).to(device), device, device)
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction="sum"),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode=request.param,
              visualize=True,
              visualize_dir=str(tmp_dir))
    return dqn
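
# A minimal usage sketch (not part of the original suite) of how a DQN built by
# the fixtures above is typically driven: collect a short synthetic episode,
# store it, and run one update. Only store_episode() and update() are confirmed
# by the code above; act_discrete_with_noise() is an assumed acting API and may
# differ in the actual framework.
def sketch_dqn_usage(dqn, train_config):
    c = train_config
    state = t.zeros([1, c.observe_dim], dtype=t.float32)
    episode = []
    for _ in range(3):
        # Assumed helper: returns a [1, 1] integer action tensor for the given
        # observation dict; replace with the real acting call of your wrapper.
        action = dqn.act_discrete_with_noise({"state": state})
        episode.append({
            "state": {"state": state},
            "action": {"action": action},
            "next_state": {"state": state},
            "reward": 0.0,
            "terminal": False,
        })
    dqn.store_episode(episode)
    dqn.update()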