def a3c():
    c = TestA3C.c
    # smw (static module wrapper) binds each model to its input/output devices
    actor = smw(
        Actor(c.observe_dim, c.action_num).to(c.device), c.device, c.device
    )
    critic = smw(Critic(c.observe_dim).to(c.device), c.device, c.device)
    # in all test scenarios, all processes will be used as reducers
    servers = grad_server_helper(
        [lambda: Actor(c.observe_dim, c.action_num), lambda: Critic(c.observe_dim)],
        learning_rate=5e-3,
    )
    a3c = A3C(
        actor,
        critic,
        nn.MSELoss(reduction="sum"),
        servers,
        replay_device="cpu",
        replay_size=c.replay_size,
    )
    return a3c
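# A minimal sketch of the test configuration object read by the fixture above.
# Only the attributes actually used there (observe_dim, action_num, device,
# replay_size) are shown; the class name and the concrete values are
# assumptions for illustration, not taken from the original test suite.
class _ExampleTestConfig:
    observe_dim = 4
    action_num = 2
    device = "cpu"
    replay_size = 10000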
@classmethod
def init_from_config(
    cls,
    config: Union[Dict[str, Any], Config],
    model_device: Union[str, t.device] = "cpu",
):
    f_config = deepcopy(config["frame_config"])
    model_cls = assert_and_get_valid_models(f_config["models"])
    model_args = f_config["model_args"]
    model_kwargs = f_config["model_kwargs"]
    models = [
        m(*arg, **kwarg).to(model_device)
        for m, arg, kwarg in zip(model_cls, model_args, model_kwargs)
    ]
    # bind m/arg/kwarg as default arguments so each creator captures its own
    # model class instead of the last one in the loop (lambda late binding)
    model_creators = [
        lambda m=m, arg=arg, kwarg=kwarg: m(*arg, **kwarg)
        for m, arg, kwarg in zip(model_cls, model_args, model_kwargs)
    ]
    optimizer = assert_and_get_valid_optimizer(f_config["optimizer"])
    criterion = assert_and_get_valid_criterion(f_config["criterion"])(
        *f_config["criterion_args"], **f_config["criterion_kwargs"]
    )
    lr_scheduler = f_config["lr_scheduler"] and assert_and_get_valid_lr_scheduler(
        f_config["lr_scheduler"]
    )
    servers = grad_server_helper(
        model_creators,
        group_name=f_config["grad_server_group_name"],
        members=f_config["grad_server_members"],
        optimizer=optimizer,
        learning_rate=[
            f_config["actor_learning_rate"],
            f_config["critic_learning_rate"],
        ],
        lr_scheduler=lr_scheduler,
        lr_scheduler_args=f_config["lr_scheduler_args"] or ((), ()),
        lr_scheduler_kwargs=f_config["lr_scheduler_kwargs"] or ({}, {}),
    )
    # the criterion entry has already been converted to an instance above,
    # so drop the raw config value before forwarding the remaining kwargs
    del f_config["criterion"]
    frame = cls(*models, criterion, servers, **f_config)
    return frame
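# A minimal sketch of a config dict accepted by init_from_config above. Only
# the frame_config keys that the method actually reads are listed; the model
# names, optimizer/criterion choices, and hyperparameter values are
# illustrative assumptions, not the library's defaults.
example_config = {
    "frame_config": {
        "models": ["Actor", "Critic"],
        "model_args": [(4, 2), (4,)],
        "model_kwargs": [{}, {}],
        "optimizer": "Adam",
        "criterion": "MSELoss",
        "criterion_args": (),
        "criterion_kwargs": {"reduction": "sum"},
        "lr_scheduler": None,
        "lr_scheduler_args": None,
        "lr_scheduler_kwargs": None,
        "actor_learning_rate": 5e-4,
        "critic_learning_rate": 1e-3,
        "grad_server_group_name": "grad_server",
        "grad_server_members": "all",
    }
}
# hypothetical usage: frame = A3C.init_from_config(example_config)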
def main(rank):
    env = gym.make("CartPole-v0")
    observe_dim = 4
    action_num = 2
    max_episodes = 2000
    max_steps = 200
    solved_reward = 190
    solved_repeat = 5

    # initialize the distributed world first
    _world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20)

    actor = Actor(observe_dim, action_num)
    critic = Critic(observe_dim)

    # in all test scenarios, all processes will be used as reducers
    servers = grad_server_helper(
        [lambda: Actor(observe_dim, action_num), lambda: Critic(observe_dim)],
        learning_rate=5e-3,
    )
    a3c = A3C(actor, critic, nn.MSELoss(reduction="sum"), servers)

    # manually control syncing to improve performance
    a3c.set_sync(False)

    # begin training
    episode, step, reward_fulfilled = 0, 0, 0
    smoothed_total_reward = 0

    while episode < max_episodes:
        episode += 1
        total_reward = 0
        terminal = False
        step = 0
        state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

        # manually pull the newest parameters
        a3c.manual_sync()
        tmp_observations = []
        while not terminal and step <= max_steps:
            step += 1
            with t.no_grad():
                old_state = state
                # agent model inference
                action = a3c.act({"state": old_state})[0]
                state, reward, terminal, _ = env.step(action.item())
                state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
                total_reward += reward

                tmp_observations.append(
                    {
                        "state": {"state": old_state},
                        "action": {"action": action},
                        "next_state": {"state": state},
                        "reward": reward,
                        "terminal": terminal or step == max_steps,
                    }
                )

        # update
        a3c.store_episode(tmp_observations)
        a3c.update()

        # show reward
        smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1
        logger.info(
            "Process {} Episode {} total reward={:.2f}".format(
                rank, episode, smoothed_total_reward
            )
        )

        if smoothed_total_reward > solved_reward:
            reward_fulfilled += 1
            if reward_fulfilled >= solved_repeat:
                logger.info("Environment solved!")
                # exiting here will cause torch RPC to complain,
                # since other processes may not have finished yet;
                # this is just for demonstration.
                exit(0)
        else:
            reward_fulfilled = 0
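# A minimal sketch of how main(rank) could be launched, one process per rank,
# to match World(world_size=3, ...) above. Using torch.multiprocessing.spawn
# here is an assumption; the original example may start its workers differently.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # spawn passes the process index (0..2) to main as its first argument
    mp.spawn(main, nprocs=3)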