コード例 #1
0
ファイル: test_a2c.py プロジェクト: yueweizhizhu/machin
 def a2c_lr(self, train_config):
     """Build an A2C framework configured with ``LambdaLR`` schedulers.

     Only used to exercise the scheduler API, not for training. It first
     checks that passing ``lr_scheduler=LambdaLR`` without
     ``lr_scheduler_args`` raises a ``TypeError`` whose message mentions a
     missing positional argument. It then builds the framework with
     ``lr_scheduler_args`` set to ``(lr_func,)`` twice, presumably one entry
     per scheduler (actor and critic) — confirm against ``A2C``.

     Args:
         train_config: test configuration; ``device``, ``observe_dim``,
             ``action_num`` and ``replay_size`` are read from it.

     Returns:
         The A2C framework built with the scheduler arguments.
     """
     cfg = train_config
     dev = cfg.device
     actor = smw(Actor(cfg.observe_dim, cfg.action_num).to(dev), dev, dev)
     critic = smw(Critic(cfg.observe_dim).to(dev), dev, dev)
     lr_func = gen_learning_rate_func(
         [(0, 1e-3), (200000, 3e-4)], logger=logger
     )
     # Keyword arguments shared by both A2C constructions below.
     shared_kwargs = {
         "replay_device": dev,
         "replay_size": cfg.replay_size,
         "lr_scheduler": LambdaLR,
     }

     # Leaving out lr_scheduler_args must fail; presumably LambdaLR ends up
     # constructed without its required lr_lambda argument.
     with pytest.raises(TypeError, match="missing .+ positional argument"):
         _ = A2C(actor, critic, t.optim.Adam, nn.MSELoss(reduction='sum'),
                 **shared_kwargs)

     return A2C(actor, critic, t.optim.Adam, nn.MSELoss(reduction='sum'),
                lr_scheduler_args=((lr_func,), (lr_func,)),
                **shared_kwargs)
コード例 #2
0
 def test_config_init(self, train_config):
     """Build A2C from a generated config and run a single update.

     Smoke test. The config names the ``Actor`` and ``Critic`` models and
     their constructor kwargs. Three identical all-zero transitions are
     stored as one episode, and ``update()`` only has to finish without
     raising; nothing is asserted on its result.
     """
     c = train_config
     config = A2C.generate_config({})
     frame_config = config["frame_config"]
     frame_config["models"] = ["Actor", "Critic"]
     actor_kwargs = {"state_dim": c.observe_dim, "action_num": c.action_num}
     critic_kwargs = {"state_dim": c.observe_dim}
     frame_config["model_kwargs"] = [actor_kwargs, critic_kwargs]
     a2c = A2C.init_from_config(config)

     # One zero tensor serves as both the current and the next state.
     zero_state = t.zeros([1, c.observe_dim], dtype=t.float32)
     zero_action = t.zeros([1, 1], dtype=t.int)
     episode = []
     for _ in range(3):
         episode.append(
             {
                 "state": {"state": zero_state},
                 "action": {"action": zero_action},
                 "next_state": {"state": zero_state},
                 "reward": 0,
                 "terminal": False,
             }
         )
     a2c.store_episode(episode)
     a2c.update()
コード例 #3
0
ファイル: test_a2c.py プロジェクト: lethaiq/machin
 def a2c_train(self, train_config):
     """Build an A2C framework placed entirely on the CPU for full training.

     Args:
         train_config: test configuration; ``observe_dim``, ``action_num``
             and ``replay_size`` are read from it.

     Returns:
         A2C framework using ``t.optim.Adam`` and an
         ``nn.MSELoss(reduction='sum')`` criterion, with its replay buffer
         on the CPU.
     """
     c = train_config
     # cpu is faster for testing full training.
     cpu = "cpu"
     actor = smw(Actor(c.observe_dim, c.action_num), cpu, cpu)
     critic = smw(Critic(c.observe_dim), cpu, cpu)
     framework = A2C(
         actor,
         critic,
         t.optim.Adam,
         nn.MSELoss(reduction='sum'),
         replay_device=cpu,
         replay_size=c.replay_size,
     )
     return framework
コード例 #4
0
 def a2c(self, train_config):
     """Build an A2C framework with models and replay buffer on one device.

     Args:
         train_config: test configuration; ``device``, ``observe_dim``,
             ``action_num`` and ``replay_size`` are read from it.

     Returns:
         A2C framework using ``t.optim.Adam`` and an
         ``nn.MSELoss(reduction='sum')`` criterion.
     """
     c = train_config
     device = c.device

     def place(module):
         # Move the module to the test device, then wrap it for that device.
         return smw(module.to(device), device, device)

     actor = place(Actor(c.observe_dim, c.action_num))
     critic = place(Critic(c.observe_dim))
     return A2C(
         actor,
         critic,
         t.optim.Adam,
         nn.MSELoss(reduction='sum'),
         replay_device=device,
         replay_size=c.replay_size,
     )
コード例 #5
0
ファイル: test_a2c.py プロジェクト: ikamensh/machin
 def a2c(self, train_config, device, dtype):
     """Build an A2C framework with models cast to ``dtype`` on ``device``.

     The replay buffer always stays on the CPU, whatever ``device`` is.

     Args:
         train_config: test configuration; ``observe_dim``, ``action_num``
             and ``replay_size`` are read from it.
         device: device the actor and critic are moved to and wrapped for.
         dtype: type passed to ``Module.type`` before moving the models.

     Returns:
         A2C framework using ``t.optim.Adam`` and an
         ``nn.MSELoss(reduction="sum")`` criterion.
     """
     cfg = train_config

     def prepare(module):
         # Cast first, then move to the target device, then wrap.
         return smw(module.type(dtype).to(device), device, device)

     actor = prepare(Actor(cfg.observe_dim, cfg.action_num))
     critic = prepare(Critic(cfg.observe_dim))
     framework_kwargs = dict(replay_device="cpu", replay_size=cfg.replay_size)
     return A2C(
         actor,
         critic,
         t.optim.Adam,
         nn.MSELoss(reduction="sum"),
         **framework_kwargs,
     )
コード例 #6
0
 def a2c_vis(self, train_config, tmpdir):
     """Build an A2C framework with visualization enabled.

     Not used for training; only exercises the visualization API.

     Args:
         train_config: test configuration; ``device``, ``observe_dim``,
             ``action_num`` and ``replay_size`` are read from it.
         tmpdir: presumably pytest's ``tmpdir`` fixture (a
             ``py.path.local``); visualization output is written beneath it.

     Returns:
         A2C framework writing visualizations into a fresh numbered
         directory created inside ``tmpdir``.
     """
     c = train_config
     # ``make_numbered_dir`` is a classmethod of ``py.path.local``. Called on
     # an instance without ``rootdir``, it creates the directory under the
     # system temp root (pruning older numbered dirs there) instead of inside
     # this test's ``tmpdir``. Pin the root so output stays test-local.
     tmp_dir = tmpdir.make_numbered_dir(rootdir=tmpdir)
     actor = smw(Actor(c.observe_dim, c.action_num)
                 .to(c.device), c.device, c.device)
     critic = smw(Critic(c.observe_dim)
                  .to(c.device), c.device, c.device)
     a2c = A2C(actor, critic,
               t.optim.Adam,
               nn.MSELoss(reduction='sum'),
               replay_device=c.device,
               replay_size=c.replay_size,
               visualize=True,
               visualize_dir=str(tmp_dir))
     return a2c