def impala(device, dtype, use_lr_sch=False):
    c = TestIMPALA.c
    # wrap the networks so the framework knows their input/output devices
    actor = smw(
        Actor(c.observe_dim, c.action_num).type(dtype).to(device), device, device
    )
    critic = smw(
        Critic(c.observe_dim).type(dtype).to(device), device, device
    )
    # model server(s) used to share parameters between the trainer and workers
    servers = model_server_helper(model_num=1)
    world = get_world()
    # processes 0 and 1 will be workers, and 2 will be the trainer
    impala_group = world.create_rpc_group("impala", ["0", "1", "2"])
    if use_lr_sch:
        lr_func = gen_learning_rate_func(
            [(0, 1e-3), (200000, 3e-4)], logger=default_logger
        )
        impala = IMPALA(
            actor,
            critic,
            t.optim.Adam,
            nn.MSELoss(reduction='sum'),
            impala_group,
            servers,
            lr_scheduler=LambdaLR,
            lr_scheduler_args=((lr_func,), (lr_func,)),
        )
    else:
        impala = IMPALA(
            actor,
            critic,
            t.optim.Adam,
            nn.MSELoss(reduction='sum'),
            impala_group,
            servers,
        )
    return impala
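
# The helper names used in this file are assumed to be imported at the top of
# the test module. A sketch of what those imports presumably look like; the
# machin module paths below are assumptions and may differ between versions
# (TestIMPALA, Actor and Critic are defined elsewhere in this test module):
from time import sleep

import torch as t
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

from machin.frame.algorithms import IMPALA
from machin.frame.helpers.servers import model_server_helper
from machin.model.nets.base import static_module_wrapper as smw
from machin.parallel.distributed import get_world
from machin.utils.learning_rate import gen_learning_rate_func
from machin.utils.logging import default_logger
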
def test_config_init(rank):
    c = TestIMPALA.c
    config = IMPALA.generate_config({})
    config["frame_config"]["models"] = ["Actor", "Critic"]
    config["frame_config"]["model_kwargs"] = [
        {"state_dim": c.observe_dim, "action_num": c.action_num},
        {"state_dim": c.observe_dim},
    ]
    impala = IMPALA.init_from_config(config)

    old_state = state = t.zeros([1, c.observe_dim], dtype=t.float32)
    action = t.zeros([1, 1], dtype=t.int)
    if rank == 0:
        # episode length = 3
        impala.store_episode([
            {
                "state": {"state": old_state},
                "action": {"action": action},
                "next_state": {"state": state},
                "reward": 0,
                "action_log_prob": 0.1,
                "terminal": False,
            }
            for _ in range(3)
        ])
    elif rank == 1:
        # episode length = 2
        impala.store_episode([
            {
                "state": {"state": old_state},
                "action": {"action": action},
                "next_state": {"state": state},
                "reward": 0,
                "action_log_prob": 0.1,
                "terminal": False,
            }
            for _ in range(2)
        ])
    if rank == 2:
        # the trainer waits briefly for the workers to push their episodes,
        # then performs a single update
        sleep(2)
        impala.update(
            update_value=True, update_target=True, concatenate_samples=True
        )
    return True
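
# IMPALA.init_from_config resolves the class names listed in
# config["frame_config"]["models"] and constructs them with the matching
# entries of "model_kwargs". The Actor and Critic classes referenced here (and
# by impala() above) are assumed to be defined elsewhere in this test module;
# a minimal sketch of what they could look like, relying on the t / nn imports
# sketched above (the real test networks may differ):

from torch.distributions import Categorical


class Actor(nn.Module):
    def __init__(self, state_dim, action_num):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state, action=None):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        probs = t.softmax(self.fc3(a), dim=1)
        dist = Categorical(probs=probs)
        act = action if action is not None else dist.sample()
        # IMPALA consumes the log probability (and entropy) of the chosen action
        return act, dist.log_prob(act.flatten()), dist.entropy()


class Critic(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, state):
        v = t.relu(self.fc1(state))
        v = t.relu(self.fc2(v))
        return self.fc3(v)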