Example #1
def a3c():
    c = TestA3C.c
    # smw (static_module_wrapper) pairs each module with its input and
    # output devices so the framework knows where tensors should live
    actor = smw(Actor(c.observe_dim, c.action_num)
                .to(c.device), c.device, c.device)
    critic = smw(Critic(c.observe_dim)
                 .to(c.device), c.device, c.device)
    # in all test scenarios, all processes will be used as reducers
    servers = grad_server_helper(
        [lambda: Actor(c.observe_dim, c.action_num),
         lambda: Critic(c.observe_dim)],
        learning_rate=5e-3
    )
    a3c = A3C(actor, critic,
              nn.MSELoss(reduction='sum'),
              servers,
              replay_device="cpu",
              replay_size=c.replay_size)
    return a3c
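
A fixture like this is typically exercised by storing a short episode and calling update. The sketch below is illustrative only and not part of the original test; it reuses the A3C methods shown in Example #3, with a3c being the object returned above, c the same test config, and t the torch alias used in the other examples:

state = t.zeros(1, c.observe_dim)
action = a3c.act({"state": state})[0]   # model inference
a3c.store_episode([{                    # store a one-step episode
    "state": {"state": state},
    "action": {"action": action},
    "next_state": {"state": state},
    "reward": 0.0,
    "terminal": True,
}])
a3c.update()                            # train through the gradient servers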
Example #2
    def init_from_config(
        cls,
        config: Union[Dict[str, Any], Config],
        model_device: Union[str, t.device] = "cpu",
    ):
        f_config = deepcopy(config["frame_config"])
        model_cls = assert_and_get_valid_models(f_config["models"])
        model_args = f_config["model_args"]
        model_kwargs = f_config["model_kwargs"]
        models = [
            m(*arg, **kwarg).to(model_device)
            for m, arg, kwarg in zip(model_cls, model_args, model_kwargs)
        ]
        model_creators = [
            # bind the loop variables as defaults: a plain `lambda:` would
            # late-bind and make every creator construct the last model
            lambda m=m, arg=arg, kwarg=kwarg: m(*arg, **kwarg)
            for m, arg, kwarg in zip(model_cls, model_args, model_kwargs)
        ]
        optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
        criterion = assert_and_get_valid_criterion(f_config["criterion"])(
            *f_config["criterion_args"], **f_config["criterion_kwargs"])
        lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
            f_config["lr_scheduler"]
        )

        servers = grad_server_helper(
            model_creators,
            group_name=f_config["grad_server_group_name"],
            members=f_config["grad_server_members"],
            optimizer=optimizer,
            learning_rate=[
                f_config["actor_learning_rate"],
                f_config["critic_learning_rate"],
            ],
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=f_config["lr_scheduler_args"] or ((), ()),
            lr_scheduler_kwargs=f_config["lr_scheduler_kwargs"] or ({}, {}),
        )
        del f_config["criterion"]
        frame = cls(*models, criterion, servers, **f_config)
        return frame
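
For orientation, here is a rough sketch of the config mapping this classmethod reads, using only the keys referenced above; every concrete value (model names, optimizer, learning rates, group name) is an illustrative assumption rather than something prescribed by the snippet:

config = {
    "frame_config": {
        "models": ["Actor", "Critic"],            # resolved by assert_and_get_valid_models
        "model_args": [(4, 2), (4,)],             # positional args for each model class
        "model_kwargs": [{}, {}],
        "optimizer": "Adam",
        "criterion": "MSELoss",
        "criterion_args": (),
        "criterion_kwargs": {"reduction": "sum"},
        "lr_scheduler": None,                     # or a scheduler name such as "LambdaLR"
        "lr_scheduler_args": None,
        "lr_scheduler_kwargs": None,
        "grad_server_group_name": "grad_server",
        "grad_server_members": "all",
        "actor_learning_rate": 5e-3,
        "critic_learning_rate": 1e-3,
    }
}

frame = A3C.init_from_config(config)              # assuming cls is A3C, as in the other examples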
Example #3
# typical imports for this example (machin, torch, gym); Actor and Critic
# are user-defined networks assumed to be defined in this script
import gym
import torch as t
from torch import nn

from machin.frame.algorithms import A3C
from machin.frame.helpers.servers import grad_server_helper
from machin.parallel.distributed import World
from machin.utils.logging import default_logger as logger


def main(rank):
    env = gym.make("CartPole-v0")
    observe_dim = 4
    action_num = 2
    max_episodes = 2000
    max_steps = 200
    solved_reward = 190
    solved_repeat = 5

    # initialize the distributed world first
    _world = World(world_size=3, rank=rank,
                   name=str(rank), rpc_timeout=20)

    actor = Actor(observe_dim, action_num)
    critic = Critic(observe_dim)

    # in all test scenarios, all processes will be used as reducers
    servers = grad_server_helper(
        [lambda: Actor(observe_dim, action_num),
         lambda: Critic(observe_dim)],
        learning_rate=5e-3
    )
    a3c = A3C(actor, critic,
              nn.MSELoss(reduction='sum'),
              servers)

    # manually control syncing to improve performance
    a3c.set_sync(False)

    # begin training
    episode, step, reward_fulfilled = 0, 0, 0
    smoothed_total_reward = 0

    while episode < max_episodes:
        episode += 1
        total_reward = 0
        terminal = False
        step = 0

        state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

        # manually pull the newest parameters
        a3c.manual_sync()
        tmp_observations = []
        while not terminal and step <= max_steps:
            step += 1
            with t.no_grad():
                old_state = state
                # agent model inference
                action = a3c.act({"state": old_state})[0]
                state, reward, terminal, _ = env.step(action.item())
                state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
                total_reward += reward

                tmp_observations.append({
                    "state": {"state": old_state},
                    "action": {"action": action},
                    "next_state": {"state": state},
                    "reward": reward,
                    "terminal": terminal or step == max_steps
                })

        # update
        a3c.store_episode(tmp_observations)
        a3c.update()

        # show reward
        smoothed_total_reward = (smoothed_total_reward * 0.9 +
                                 total_reward * 0.1)
        logger.info("Process {} Episode {} total reward={:.2f}"
                    .format(rank, episode, smoothed_total_reward))

        if smoothed_total_reward > solved_reward:
            reward_fulfilled += 1
            if reward_fulfilled >= solved_repeat:
                logger.info("Environment solved!")
                # exiting here will cause torch RPC to complain,
                # since other processes may not have finished yet;
                # this is just for demonstration.
                exit(0)
        else:
            reward_fulfilled = 0
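
Since World is created with world_size=3, main(rank) has to be launched in three cooperating processes. A minimal launcher sketch (not part of the original example) using the standard multiprocessing module:

if __name__ == "__main__":
    from multiprocessing import Process

    # one process per rank, matching world_size=3 above
    processes = [Process(target=main, args=(rank,)) for rank in range(3)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()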