Code Example #1
def dqn_lr(self, train_config):
    # not used for training, only for exercising the lr_scheduler API
    c = train_config
    q_net = smw(
        QNet(c.observe_dim, c.action_num).to(c.device), c.device, c.device)
    q_net_t = smw(
        QNet(c.observe_dim, c.action_num).to(c.device), c.device, c.device)
    # piecewise-constant schedule: lr 1e-3 until step 200000, then 3e-4
    lr_func = gen_learning_rate_func([(0, 1e-3), (200000, 3e-4)],
                                     logger=logger)
    # passing lr_scheduler without lr_scheduler_args must fail, since
    # LambdaLR requires the lr_lambda positional argument
    with pytest.raises(TypeError, match="missing .+ positional argument"):
        _ = DQN(q_net,
                q_net_t,
                t.optim.Adam,
                nn.MSELoss(reduction='sum'),
                replay_device=c.device,
                replay_size=c.replay_size,
                lr_scheduler=LambdaLR)
    dqn = DQN(q_net,
              q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction='sum'),
              replay_device=c.device,
              replay_size=c.replay_size,
              lr_scheduler=LambdaLR,
              # one tuple of positional args per scheduler instance
              lr_scheduler_args=((lr_func, ), ))
    return dqn
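The nested tuple in lr_scheduler_args supplies one tuple of positional arguments per scheduler; here the single inner tuple carries lr_func for the one LambdaLR. A minimal standalone sketch of the same wiring in plain PyTorch, with the piecewise schedule hand-written as a stand-in for what gen_learning_rate_func produces (machin is not needed here):

import torch as t
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

def lr_func(step):
    # piecewise-constant schedule: 1e-3 before step 200000, 3e-4 afterwards
    return 1e-3 if step < 200000 else 3e-4

model = nn.Linear(4, 2)
# base lr of 1.0, so the lambda's return value becomes the effective lr
optimizer = t.optim.Adam(model.parameters(), lr=1.0)
scheduler = LambdaLR(optimizer, lr_func)
for _ in range(3):
    optimizer.step()
    scheduler.step()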
Code Example #2
File: test_dqn.py Project: ikamensh/machin
def test_config_init(self, train_config):
    c = train_config
    # build a DQN purely from a generated config dict;
    # the two models are referenced by class name
    config = DQN.generate_config({})
    config["frame_config"]["models"] = ["QNet", "QNet"]
    config["frame_config"]["model_kwargs"] = [{
        "state_dim": c.observe_dim,
        "action_num": c.action_num
    }] * 2
    dqn = DQN.init_from_config(config)

    old_state = state = t.zeros([1, c.observe_dim], dtype=t.float32)
    action = t.zeros([1, 1], dtype=t.int)
    # store a short episode of dummy transitions, then run one update
    dqn.store_episode([{
        "state": {"state": old_state},
        "action": {"action": action},
        "next_state": {"state": state},
        "reward": 0,
        "terminal": False,
    } for _ in range(3)])
    dqn.update()
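store_episode takes a list of transition dicts whose sub-dicts mirror the keyword inputs and outputs of the wrapped network. As a sketch of how such transitions might be collected from a gym-style environment (env and act_discrete_with_noise are assumptions, not shown in the test above):

episode, terminal = [], False
state = t.tensor(env.reset(), dtype=t.float32).view(1, -1)  # hypothetical env
while not terminal:
    old_state = state
    with t.no_grad():
        # assumption: machin's DQN exposes act_discrete_with_noise for exploration
        action = dqn.act_discrete_with_noise({"state": old_state})
    obs, reward, terminal, _ = env.step(action.item())
    state = t.tensor(obs, dtype=t.float32).view(1, -1)
    episode.append({
        "state": {"state": old_state},
        "action": {"action": action},
        "next_state": {"state": state},
        "reward": float(reward),
        "terminal": terminal,
    })
dqn.store_episode(episode)
dqn.update()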
Code Example #3
def dqn_train(self, train_config, request):
    c = train_config
    # cpu is faster for testing full training
    q_net = smw(QNet(c.observe_dim, c.action_num), "cpu", "cpu")
    q_net_t = smw(QNet(c.observe_dim, c.action_num), "cpu", "cpu")
    # request.param supplies the DQN mode from the fixture's params list
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction='sum'),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode=request.param)
    return dqn
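request.param only exists when the fixture itself is parametrized, so this fixture presumably carries a params list naming the DQN modes to test, roughly as below (the exact mode names are an assumption, extrapolated from the "double" mode used in Code Example #4):

import pytest

@pytest.fixture(params=["vanilla", "fixed_target", "double"])
def dqn_train(self, train_config, request):
    ...  # body as above; every test using this fixture runs once per mode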
Code Example #4
File: test_dqn.py Project: ikamensh/machin
def test_mode(self, train_config, device, dtype):
    c = train_config
    q_net = smw(
        QNet(c.observe_dim, c.action_num).type(dtype).to(device), device,
        device)
    q_net_t = smw(
        QNet(c.observe_dim, c.action_num).type(dtype).to(device), device,
        device)

    # constructing with an unrecognized mode must fail immediately
    with pytest.raises(ValueError, match="Unknown DQN mode"):
        _ = DQN(
            q_net,
            q_net_t,
            t.optim.Adam,
            nn.MSELoss(reduction="sum"),
            replay_device="cpu",
            replay_size=c.replay_size,
            mode="invalid_mode",
        )

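    # "double" is a valid mode, so the construction below succeeds; the
    # expected ValueError only surfaces inside update(), after mode is
    # overwritten with an invalid value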
    with pytest.raises(ValueError, match="Unknown DQN mode"):
        dqn = DQN(
            q_net,
            q_net_t,
            t.optim.Adam,
            nn.MSELoss(reduction="sum"),
            replay_device="cpu",
            replay_size=c.replay_size,
            mode="double",
        )

        old_state = state = t.zeros([1, c.observe_dim], dtype=dtype)
        action = t.zeros([1, 1], dtype=t.int)
        dqn.store_episode([{
            "state": {"state": old_state},
            "action": {"action": action},
            "next_state": {"state": state},
            "reward": 0,
            "terminal": False,
        } for _ in range(3)])
        dqn.mode = "invalid_mode"
        dqn.update(update_value=True,
                   update_target=True,
                   concatenate_samples=True)
Code Example #5
def dqn(self, train_config, device, dtype, request):
    c = train_config
    q_net = smw(QNet(c.observe_dim, c.action_num)
                .type(dtype).to(device), device, device)
    q_net_t = smw(QNet(c.observe_dim, c.action_num)
                  .type(dtype).to(device), device, device)
    # the replay buffer is kept on cpu; mode comes from fixture params
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction='sum'),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode=request.param)
    return dqn
Code Example #6
def dqn_vis(self, train_config, device, dtype, tmpdir, request):
    c = train_config
    # fresh numbered directory under pytest's tmpdir for visualization output
    tmp_dir = tmpdir.make_numbered_dir()
    q_net = smw(QNet(c.observe_dim, c.action_num)
                .type(dtype).to(device), device, device)
    q_net_t = smw(QNet(c.observe_dim, c.action_num)
                  .type(dtype).to(device), device, device)
    dqn = DQN(q_net, q_net_t,
              t.optim.Adam,
              nn.MSELoss(reduction='sum'),
              replay_device="cpu",
              replay_size=c.replay_size,
              mode=request.param,
              visualize=True,
              visualize_dir=str(tmp_dir))
    return dqn
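With visualize=True and a visualize_dir, the framework presumably writes per-update visualizations of the Q-network's computation graph into the freshly created numbered directory, so each parametrized run gets its own output folder and repeated runs do not clobber each other's files.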