コード例 #1
0
ファイル: test_dqn_per.py プロジェクト: iffiX/machin
    def test_config_init(self, train_config):
        """Smoke-test building DQNPer from a generated config dict.

        The frame config declares two ``QNet`` models by name. Three identical
        all-zero transitions are stored as one episode, and a single
        ``update()`` call must then complete without raising.
        """
        cfg = train_config
        config = DQNPer.generate_config({})
        net_kwargs = {
            "state_dim": cfg.observe_dim,
            "action_num": cfg.action_num,
        }
        config["frame_config"]["models"] = ["QNet", "QNet"]
        # Both entries deliberately reference the very same kwargs dict.
        config["frame_config"]["model_kwargs"] = [net_kwargs] * 2
        dqn_per = DQNPer.init_from_config(config)

        zero_state = t.zeros([1, cfg.observe_dim], dtype=t.float32)
        zero_action = t.zeros([1, 1], dtype=t.int)

        def make_transition():
            # Build fresh (outer and inner) dicts on every call; the same
            # zero tensor serves as both the current and the next state.
            return {
                "state": {"state": zero_state},
                "action": {"action": zero_action},
                "next_state": {"state": zero_state},
                "reward": 0,
                "terminal": False,
            }

        episode = [make_transition() for _ in range(3)]
        dqn_per.store_episode(episode)
        dqn_per.update()
コード例 #2
0
ファイル: test_dqn_per.py プロジェクト: lethaiq/machin
 def dqn_per_train(self, train_config):
     """Build a CPU-only DQNPer framework for full-training tests.

     Online and target ``QNet`` instances are created in that order, wrapped
     with ``smw`` on the CPU, and optimized with Adam against a summed MSE
     loss. The replay buffer also lives on the CPU.
     """
     cfg = train_config
     # Full training in tests runs faster on the CPU than on an accelerator.
     device = "cpu"
     online_net, target_net = (
         smw(QNet(cfg.observe_dim, cfg.action_num), device, device)
         for _ in range(2)
     )
     return DQNPer(
         online_net,
         target_net,
         t.optim.Adam,
         nn.MSELoss(reduction='sum'),
         replay_device="cpu",
         replay_size=cfg.replay_size,
     )
コード例 #3
0
 def dqn_per(self, train_config):
     """Build a DQNPer framework whose networks live on ``train_config.device``.

     Each ``QNet`` is moved to the configured device and wrapped with ``smw``,
     which receives that same device for both of its device arguments. The
     replay buffer is kept on the CPU.
     """
     cfg = train_config
     dev = cfg.device

     def build_qnet():
         # One device-resident, wrapped Q network.
         return smw(QNet(cfg.observe_dim, cfg.action_num).to(dev), dev, dev)

     online = build_qnet()
     target = build_qnet()
     return DQNPer(
         online,
         target,
         t.optim.Adam,
         nn.MSELoss(reduction='sum'),
         replay_device="cpu",
         replay_size=cfg.replay_size,
     )
コード例 #4
0
 def dqn_per_vis(self, train_config, tmpdir):
     """Build a DQNPer framework with visualization enabled.

     Visualization output goes to a fresh numbered directory created inside
     pytest's per-test ``tmpdir``. The networks live on ``train_config.device``
     and the replay buffer on the CPU.
     """
     c = train_config
     # ``make_numbered_dir`` is a classmethod of py.path.local. Called on the
     # ``tmpdir`` instance without ``rootdir``, it ignores ``tmpdir`` and
     # creates (and may prune older) numbered dirs in the system temp root.
     # Pass ``rootdir`` explicitly so the output stays test-local.
     tmp_dir = tmpdir.make_numbered_dir(rootdir=tmpdir)
     q_net = smw(
         QNet(c.observe_dim, c.action_num).to(c.device), c.device, c.device)
     q_net_t = smw(
         QNet(c.observe_dim, c.action_num).to(c.device), c.device, c.device)
     dqn_per = DQNPer(q_net,
                      q_net_t,
                      t.optim.Adam,
                      nn.MSELoss(reduction='sum'),
                      replay_device="cpu",
                      replay_size=c.replay_size,
                      visualize=True,
                      visualize_dir=str(tmp_dir))
     return dqn_per
コード例 #5
0
    def test_criterion(self, train_config):
        """Constructing DQNPer with a criterion lacking ``reduction`` must fail.

        A plain function is passed as the criterion. The constructor must raise
        ``RuntimeError`` saying the criterion has no ``reduction`` property.
        No other invalid argument is passed, so that error can only come from
        the criterion check.
        """
        c = train_config
        q_net = smw(
            QNet(c.observe_dim, c.action_num).to(c.device), c.device, c.device)
        q_net_t = smw(
            QNet(c.observe_dim, c.action_num).to(c.device), c.device, c.device)

        def criterion(a, b):
            # Plain function: deliberately has no ``reduction`` attribute.
            return a - b

        # Keep only the constructor call inside ``pytest.raises``. The former
        # ``mode="invalid_mode"`` argument was dropped so the outcome does not
        # depend on the order of validation inside the constructor.
        with pytest.raises(RuntimeError,
                           match="Criterion does not have the "
                           "'reduction' property"):
            DQNPer(q_net,
                   q_net_t,
                   t.optim.Adam,
                   criterion,
                   replay_device="cpu",
                   replay_size=c.replay_size)