Ejemplo n.º 1
0
    def impala(device, dtype, use_lr_sch=False):
        """Build an IMPALA framework instance for the distributed tests.

        The actor and critic are cast with ``.type(dtype)``, moved with
        ``.to(device)`` and wrapped by ``smw(module, device, device)``. A
        model server group of size one is requested, and an RPC group named
        "impala" is created over processes "0", "1" and "2".

        Args:
            device: Target device for the actor and critic modules. It is
                also passed twice to ``smw``.
            dtype: Tensor type handed to ``Module.type`` for both modules.
            use_lr_sch: When true, ``LambdaLR`` is passed as the scheduler
                class. Its arguments are two tuples, each holding the same
                function built by ``gen_learning_rate_func`` from the
                breakpoints (0, 1e-3) and (200000, 3e-4). Presumably there is
                one tuple per optimizer (actor, critic); verify against
                ``IMPALA``.

        Returns:
            The constructed ``IMPALA`` instance.
        """
        config = TestIMPALA.c

        def wrap(module):
            # Cast, move, then wrap; same call order as the unrolled form.
            return smw(module.type(dtype).to(device), device, device)

        actor = wrap(Actor(config.observe_dim, config.action_num))
        critic = wrap(Critic(config.observe_dim))
        servers = model_server_helper(model_num=1)
        world = get_world()
        # Processes "0" and "1" act as workers; process "2" is the trainer.
        impala_group = world.create_rpc_group("impala", ["0", "1", "2"])

        # Scheduler options are the only difference between the two setups,
        # so collect them separately and splat them into a single call.
        extra_kwargs = {}
        if use_lr_sch:
            lr_func = gen_learning_rate_func([(0, 1e-3), (200000, 3e-4)],
                                             logger=default_logger)
            extra_kwargs["lr_scheduler"] = LambdaLR
            extra_kwargs["lr_scheduler_args"] = ((lr_func, ), (lr_func, ))

        return IMPALA(actor, critic, t.optim.Adam,
                      nn.MSELoss(reduction='sum'), impala_group, servers,
                      **extra_kwargs)
Ejemplo n.º 2
0
    def test_config_init(rank):
        """Create IMPALA from a generated config and run one tiny round.

        The generated config is filled in with ``Actor``/``Critic`` model
        names and their keyword arguments before ``IMPALA.init_from_config``
        is called. The three ranks then behave differently:

        - rank 0 stores one dummy episode of 3 transitions;
        - rank 1 stores one dummy episode of 2 transitions;
        - rank 2 sleeps for 2 seconds, then calls ``update`` with
          ``concatenate_samples=True``.

        Every transition uses zero tensors for its state, next state and
        action, reward 0, action log-probability 0.1, and is non-terminal.

        Args:
            rank: Index of the calling process. Presumably this is supplied
                by the multi-process test harness; verify.

        Returns:
            Always ``True``.
        """
        c = TestIMPALA.c
        config = IMPALA.generate_config({})
        frame_config = config["frame_config"]
        frame_config["models"] = ["Actor", "Critic"]
        frame_config["model_kwargs"] = [
            {"state_dim": c.observe_dim, "action_num": c.action_num},
            {"state_dim": c.observe_dim},
        ]
        impala = IMPALA.init_from_config(config)

        # The same zero tensor serves as both state and next state.
        zero_state = t.zeros([1, c.observe_dim], dtype=t.float32)
        zero_action = t.zeros([1, 1], dtype=t.int)

        def make_episode(length):
            # Fresh dicts per transition; the tensors themselves are shared.
            return [
                {
                    "state": {"state": zero_state},
                    "action": {"action": zero_action},
                    "next_state": {"state": zero_state},
                    "reward": 0,
                    "action_log_prob": 0.1,
                    "terminal": False,
                }
                for _ in range(length)
            ]

        if rank == 0:
            impala.store_episode(make_episode(3))
        elif rank == 1:
            impala.store_episode(make_episode(2))
        elif rank == 2:
            # Presumably gives the workers time to push their episodes
            # before the trainer updates.
            sleep(2)
            impala.update(update_value=True,
                          update_target=True,
                          concatenate_samples=True)
        return True