Beispiel #1
0
    def test_config_init(rank):
        """Build ``DDPGApex`` from a generated config and exercise it by rank.

        A default config is produced with ``DDPGApex.generate_config({})``,
        its ``frame_config`` is filled with two ``"Actor"`` and two
        ``"Critic"`` entries plus their constructor kwargs, and the framework
        is created via ``DDPGApex.init_from_config``.

        Ranks 1 and 2 store a three-transition episode of zero tensors, sleep
        for 5 seconds and call ``manual_sync()``. Rank 0 sleeps for 2 seconds
        and calls ``update()``. Any other rank only builds the framework.

        Args:
            rank: Index of the calling process.

        Returns:
            Always ``True``.
        """
        c = TestDDPGApex.c

        config = DDPGApex.generate_config({})
        frame_config = config["frame_config"]
        frame_config["models"] = ["Actor", "Actor", "Critic", "Critic"]

        # Each kwargs dict is referenced twice, so both actor entries (and
        # both critic entries) point at one shared dict object.
        actor_kwargs = {
            "state_dim": c.observe_dim,
            "action_dim": c.action_dim,
            "action_range": c.action_range,
        }
        critic_kwargs = {"state_dim": c.observe_dim, "action_dim": c.action_dim}
        frame_config["model_kwargs"] = [
            actor_kwargs,
            actor_kwargs,
            critic_kwargs,
            critic_kwargs,
        ]

        ddpg_apex = DDPGApex.init_from_config(config)

        # One zero tensor serves as both the current and the next state.
        zero_state = t.zeros([1, c.observe_dim], dtype=t.float32)
        zero_action = t.zeros([1, c.action_dim], dtype=t.float32)

        if rank == 0:
            sleep(2)
            ddpg_apex.update()
        elif rank in (1, 2):
            episode = []
            for _ in range(3):
                episode.append(
                    {
                        "state": {"state": zero_state},
                        "action": {"action": zero_action},
                        "next_state": {"state": zero_state},
                        "reward": 0,
                        "terminal": False,
                    }
                )
            ddpg_apex.store_episode(episode)
            sleep(5)
            ddpg_apex.manual_sync()

        return True
Beispiel #2
0
    def ddpg_apex(device, dtype, discrete=False):
        """Construct a ``DDPGApex`` framework for the tests.

        Two actors and two critics are built (an online network and its
        ``_t`` counterpart for each), cast to ``dtype``, moved to ``device``
        and wrapped with ``smw``. The framework is attached to an RPC group
        named ``"worker"`` and to the servers returned by
        ``model_server_helper(model_num=2)``.

        Args:
            device: Device the networks are moved to; also passed twice to
                ``smw``.
            dtype: Type the networks are cast to with ``.type(dtype)``.
            discrete: If true, ``ActorDiscrete`` is used for the actors
                instead of ``Actor``.

        Returns:
            The constructed ``DDPGApex`` instance.
        """
        c = TestDDPGApex.c

        def wrap(model):
            # smw takes the device twice -- presumably the input and output
            # devices; confirm against smw's signature.
            return smw(model.type(dtype).to(device), device, device)

        if discrete:
            actor_cls = ActorDiscrete
            actor_args = (c.observe_dim, c.action_dim)
        else:
            actor_cls = Actor
            actor_args = (c.observe_dim, c.action_dim, c.action_range)

        # Construction order is kept as actor, actor_t, critic, critic_t.
        actor = wrap(actor_cls(*actor_args))
        actor_t = wrap(actor_cls(*actor_args))
        critic = wrap(Critic(c.observe_dim, c.action_dim))
        critic_t = wrap(Critic(c.observe_dim, c.action_dim))

        servers = model_server_helper(model_num=2)

        # The group spans processes "0", "1" and "2"; per the original note,
        # 0 and 1 will be workers and 2 will be the trainer (the role
        # assignment itself is not visible in this function).
        apex_group = get_world().create_rpc_group("worker", ["0", "1", "2"])

        return DDPGApex(
            actor,
            actor_t,
            critic,
            critic_t,
            t.optim.Adam,
            nn.MSELoss(reduction='sum'),
            apex_group,
            servers,
            replay_device="cpu",
            replay_size=c.replay_size,
        )