def test_config_init(rank):
    """Build a DDPGApex instance from a generated config and run one cycle.

    Ranks 1 and 2 act as workers: they push a short dummy episode, wait,
    then pull the latest models. Rank 0 acts as the trainer: it waits for
    the workers, then performs a single update step. Always returns True.
    """
    c = TestDDPGApex.c

    # Fill the generated default config with the four model specs:
    # two actors (online + target) followed by two critics.
    config = DDPGApex.generate_config({})
    config["frame_config"]["models"] = ["Actor", "Actor", "Critic", "Critic"]
    actor_kwargs = {
        "state_dim": c.observe_dim,
        "action_dim": c.action_dim,
        "action_range": c.action_range,
    }
    critic_kwargs = {"state_dim": c.observe_dim, "action_dim": c.action_dim}
    config["frame_config"]["model_kwargs"] = (
        [actor_kwargs] * 2 + [critic_kwargs] * 2
    )
    ddpg_apex = DDPGApex.init_from_config(config)

    # Dummy transition tensors; state and next_state alias the same zeros.
    old_state = state = t.zeros([1, c.observe_dim], dtype=t.float32)
    action = t.zeros([1, c.action_dim], dtype=t.float32)

    if rank in (1, 2):
        # Worker: store a three-step episode, then sync models from servers.
        episode = [
            {
                "state": {"state": old_state},
                "action": {"action": action},
                "next_state": {"state": state},
                "reward": 0,
                "terminal": False,
            }
            for _ in range(3)
        ]
        ddpg_apex.store_episode(episode)
        sleep(5)
        ddpg_apex.manual_sync()
    if rank == 0:
        # Trainer: give workers time to fill the buffer, then train once.
        sleep(2)
        ddpg_apex.update()
    return True
def ddpg_apex(device, dtype, discrete=False):
    """Construct a DDPGApex framework instance for distributed tests.

    Builds online and target actor/critic networks (a discrete-action actor
    when ``discrete`` is True), moves each to ``device``/``dtype`` and wraps
    it with ``smw``, starts two model servers, and creates the RPC group.
    """
    c = TestDDPGApex.c

    def wrap(net):
        # Cast/move a freshly built network, then wrap it for the device.
        return smw(net.type(dtype).to(device), device, device)

    if discrete:
        actor = wrap(ActorDiscrete(c.observe_dim, c.action_dim))
        actor_t = wrap(ActorDiscrete(c.observe_dim, c.action_dim))
    else:
        actor = wrap(Actor(c.observe_dim, c.action_dim, c.action_range))
        actor_t = wrap(Actor(c.observe_dim, c.action_dim, c.action_range))
    critic = wrap(Critic(c.observe_dim, c.action_dim))
    critic_t = wrap(Critic(c.observe_dim, c.action_dim))

    servers = model_server_helper(model_num=2)
    world = get_world()
    # process 0 and 1 will be workers, and 2 will be trainer
    apex_group = world.create_rpc_group("worker", ["0", "1", "2"])

    return DDPGApex(
        actor,
        actor_t,
        critic,
        critic_t,
        t.optim.Adam,
        nn.MSELoss(reduction="sum"),
        apex_group,
        servers,
        replay_device="cpu",
        replay_size=c.replay_size,
    )