def sampling_process(common_kwargs, worker_kwargs):
    """Target function used for forking parallel worker processes in the
    samplers.  After ``initialize_worker()``, it creates the specified number
    of environment instances and gives them to the collector when
    instantiating it.  It then calls collector startup methods for
    environments and agent.  If applicable, instantiates evaluation
    environment instances and an evaluation collector.  Then enters an
    infinite loop, waiting for signals from the master process to collect
    training samples or run evaluation, until signaled to exit.
    """
    c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs)
    initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads)
    envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)]
    set_envs_seeds(envs, w.seed)

    collector = c.CollectorCls(
        rank=w.rank,
        envs=envs,
        samples_np=w.samples_np,
        batch_T=c.batch_T,
        TrajInfoCls=c.TrajInfoCls,
        agent=c.get("agent", None),  # Optional depending on parallel setup.
        sync=w.get("sync", None),
        step_buffer_np=w.get("step_buffer_np", None),
        global_B=c.get("global_B", 1),
        env_ranks=w.get("env_ranks", None),
    )
    agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps)
    collector.start_agent()

    if c.get("eval_n_envs", 0) > 0:
        eval_envs = [c.EnvCls(**c.eval_env_kwargs)
                     for _ in range(c.eval_n_envs)]
        set_envs_seeds(eval_envs, w.seed)
        eval_collector = c.eval_CollectorCls(
            rank=w.rank,
            envs=eval_envs,
            TrajInfoCls=c.TrajInfoCls,
            traj_infos_queue=c.eval_traj_infos_queue,
            max_T=c.eval_max_T,
            agent=c.get("agent", None),
            sync=w.get("sync", None),
            step_buffer_np=w.get("eval_step_buffer_np", None),
        )
    else:
        eval_envs = list()

    ctrl = c.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            break
        if ctrl.do_eval.value:
            eval_collector.collect_evaluation(ctrl.itr.value)  # Traj_infos to queue inside.
        else:
            agent_inputs, traj_infos, completed_infos = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            for info in completed_infos:
                c.traj_infos_queue.put(info)
        ctrl.barrier_out.wait()

    for env in envs + eval_envs:
        env.close()
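
# --- Illustrative sketch (not part of the library) ---------------------------
# ``sampling_process`` is intended as the target of forked worker processes.
# A minimal sketch of how a master sampler could launch them, assuming
# ``common_kwargs`` and ``workers_kwargs`` have already been assembled the way
# the sampler's initialize() does; ``fork_workers`` is a hypothetical helper.
import multiprocessing as mp


def fork_workers(common_kwargs, workers_kwargs):
    """Start one worker process per entry in ``workers_kwargs``."""
    workers = [
        mp.Process(
            target=sampling_process,
            kwargs=dict(common_kwargs=common_kwargs, worker_kwargs=w_kwargs),
        ) for w_kwargs in workers_kwargs
    ]
    for w in workers:
        w.start()
    return workers
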
def sampling_process(common_kwargs, worker_kwargs):
    """Target function used for forking parallel worker processes in the
    samplers.  After ``initialize_worker()``, it creates the specified number
    of environment instances and gives them to the collector when
    instantiating it.  It then calls collector startup methods for
    environments and agent.  If applicable, instantiates evaluation
    environment instances and an evaluation collector.  Then enters an
    infinite loop, waiting for signals from the master process to collect
    training samples or run evaluation, until signaled to exit.
    """
    c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs)
    initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads)
    envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)]

    # Only the first environment of this worker logs heatmaps, if enabled.
    log_heatmaps = c.env_kwargs.get('log_heatmaps', None)
    if log_heatmaps:
        for env in envs[1:]:
            env.log_heatmaps = False

    # Optional video recording of the first environment.
    if c.record_freq > 0:
        if c.env_kwargs['game'] in ATARI_ENVS:
            envs[0].record_env = True
            os.makedirs(os.path.join(c.log_dir, 'videos/frames'))
        elif c.get("eval_n_envs", 0) == 0:
            # Only record in workers when no evaluation processes run.
            envs[0] = Monitor(
                envs[0],
                c.log_dir + '/videos',
                video_callable=lambda episode_id: episode_id % c.record_freq == 0)

    set_envs_seeds(envs, w.seed)
    collector = c.CollectorCls(
        rank=w.rank,
        envs=envs,
        samples_np=w.samples_np,
        batch_T=c.batch_T,
        TrajInfoCls=c.TrajInfoCls,
        agent=c.get("agent", None),  # Optional depending on parallel setup.
        sync=w.get("sync", None),
        step_buffer_np=w.get("step_buffer_np", None),
        global_B=c.get("global_B", 1),
        env_ranks=w.get("env_ranks", None),
        no_extrinsic=c.no_extrinsic)
    agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps)
    collector.start_agent()

    if c.get("eval_n_envs", 0) > 0:
        eval_envs = [c.EnvCls(**c.eval_env_kwargs)
                     for _ in range(c.eval_n_envs)]
        if c.record_freq > 0:
            eval_envs[0] = Monitor(
                eval_envs[0],
                c.log_dir + '/videos',
                video_callable=lambda episode_id: episode_id % c.record_freq == 0)
        set_envs_seeds(eval_envs, w.seed)
        eval_collector = c.eval_CollectorCls(
            rank=w.rank,
            envs=eval_envs,
            TrajInfoCls=c.TrajInfoCls,
            traj_infos_queue=c.eval_traj_infos_queue,
            max_T=c.eval_max_T,
            agent=c.get("agent", None),
            sync=w.get("sync", None),
            step_buffer_np=w.get("eval_step_buffer_np", None),
        )
    else:
        eval_envs = list()

    ctrl = c.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            logger.log('Quitting worker ...')
            break
        if ctrl.do_eval.value:
            eval_collector.collect_evaluation(ctrl.itr.value)  # Traj_infos to queue inside.
        else:
            agent_inputs, traj_infos, completed_infos = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            for info in completed_infos:
                c.traj_infos_queue.put(info)
        ctrl.barrier_out.wait()

    for env in envs + eval_envs:
        logger.log('Stopping env ...')
        env.close()
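
# --- Illustrative sketch (not part of the library) ---------------------------
# The video recording above relies on ``gym.wrappers.Monitor``'s
# ``video_callable``: it receives the episode index and returns True for
# episodes that should be recorded.  The environment id and output directory
# below are placeholders.
import gym
from gym.wrappers import Monitor

record_freq = 10  # record every 10th episode
recorded_env = Monitor(
    gym.make('CartPole-v1'),
    './videos',
    video_callable=lambda episode_id: episode_id % record_freq == 0,
    force=True,
)
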
def initialize(
        self,
        agent,
        affinity=None,
        seed=None,
        bootstrap_value=False,
        traj_info_kwargs=None,
        rank=0,
        world_size=1,
        ):
    """Store the input arguments.  Instantiate the specified number of
    environment instances (``batch_B``).  Initialize the agent, and
    pre-allocate a memory buffer to hold the samples collected in each batch.
    Applies ``traj_info_kwargs`` settings to the ``TrajInfoCls`` by direct
    class attribute assignment.  Instantiates the collector and, if
    applicable, the evaluation collector.

    Returns a structure of individual examples for data fields such as
    ``observation``, ``action``, etc., which can be used to allocate a
    replay buffer.
    """
    B = self.batch_spec.B
    envs = [self.EnvCls(**self.env_kwargs) for _ in range(B)]
    set_envs_seeds(envs, seed)  # Random seed made in runner.
    global_B = B * world_size
    env_ranks = list(range(rank * B, (rank + 1) * B))
    agent.observer.initialize(envs[0].spaces()[1],
                              share_memory=False,
                              global_B=global_B,
                              env_ranks=env_ranks)
    agent.player.initialize(envs[0].spaces()[0],
                            share_memory=False,
                            global_B=global_B,
                            env_ranks=env_ranks)
    observer_samples_pyt, observer_samples_np, observer_examples = build_samples_buffer(
        agent.observer, envs[0], self.batch_spec, bootstrap_value,
        agent_shared=False, env_shared=False, subprocess=False)
    player_samples_pyt, player_samples_np, player_examples = build_samples_buffer(
        agent.player, envs[0], self.batch_spec, bootstrap_value,
        agent_shared=False, env_shared=False, subprocess=False)
    if traj_info_kwargs:
        for k, v in traj_info_kwargs.items():
            setattr(self.TrajInfoCls, "_" + k, v)  # Avoid passing at init.
    collector = self.CollectorCls(
        rank=0,
        envs=envs,
        player_samples_np=player_samples_np,
        observer_samples_np=observer_samples_np,
        batch_T=self.batch_spec.T,
        TrajInfoCls=self.TrajInfoCls,
        agent=agent,
        global_B=global_B,
        env_ranks=env_ranks,  # Might get applied redundantly to agent.
    )
    if self.eval_n_envs > 0:  # May do evaluation.
        eval_envs = [self.EnvCls(**self.eval_env_kwargs)
                     for _ in range(self.eval_n_envs)]
        set_envs_seeds(eval_envs, seed)
        eval_CollectorCls = self.eval_CollectorCls or SerialEvalCollector
        self.eval_collector = eval_CollectorCls(
            envs=eval_envs,
            agent=agent,
            TrajInfoCls=self.TrajInfoCls,
            max_T=self.eval_max_steps // self.eval_n_envs,
            max_trajectories=self.eval_max_trajectories,
        )

    (player_agent_inputs, player_traj_infos, observer_agent_inputs,
     observer_traj_infos) = collector.start_envs(self.max_decorrelation_steps)
    collector.start_agent()

    self.agent = agent
    self.player_samples_pyt = player_samples_pyt
    self.player_samples_np = player_samples_np
    self.observer_samples_pyt = observer_samples_pyt
    self.observer_samples_np = observer_samples_np
    self.collector = collector
    self.player_agent_inputs = player_agent_inputs
    self.player_traj_infos = player_traj_infos
    self.observer_agent_inputs = observer_agent_inputs
    self.observer_traj_infos = observer_traj_infos
    logger.log("Serial Sampler initialized.")
    return player_examples, observer_examples
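
# --- Illustrative usage sketch (not part of the library) ---------------------
# How a runner might call this two-agent serial sampler's ``initialize``.
# The sampler/agent/affinity construction is elided here; ``make_seed`` comes
# from rlpyt's seed utilities, and ``discount`` is one example of a
# ``traj_info_kwargs`` entry (applied above as ``TrajInfoCls._discount``).
player_examples, observer_examples = sampler.initialize(
    agent=agent,
    affinity=affinity,
    seed=make_seed(),
    bootstrap_value=False,
    traj_info_kwargs=dict(discount=0.99),
    rank=0,
    world_size=1,
)
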
def build_and_train(game="cartpole",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    eval=False,
                    serial=False,
                    train_mask=[True, True],
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    observation_mode="agent",
                    inc_player_last_act=False,
                    alt_train=False,
                    eval_perf=False,
                    n_steps=50e6,
                    one_agent=False):
    # Environment definitions:
    if observation_mode == "agent":
        fully_obs = False
        rand_obs = False
    elif observation_mode == "random":
        fully_obs = False
        rand_obs = True
    elif observation_mode == "full":
        fully_obs = True
        rand_obs = False

    n_serial = None
    if game == "cartpole":
        work_env = gym.make
        env_name = 'CartPole-v1'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, -4.8000002e+00, -3.4028235e+38,
            -4.1887903e-01, -3.4028235e+38
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 4.8000002e+00, 3.4028235e+38, 4.1887903e-01,
            3.4028235e+38
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        player_act_space.shape = (1, )
        print(player_act_space)
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 1),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_cartpole
        observer_reward_shaping = observer_reward_shaping_cartpole
        max_decor_steps = 20
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[24],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)
    elif game == "hiv":
        work_env = wn.gym.make
        env_name = 'HIV-v0'
        cont_act = False
        state_space_low = np.asarray(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, np.inf, np.inf, np.inf, np.inf,
            np.inf, np.inf
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 3),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hiv
        observer_reward_shaping = observer_reward_shaping_hiv
        max_decor_steps = 10
        b_size = 32
        num_envs = 8
        max_episode_length = 100
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)
    elif game == "heparin":
        work_env = HeparinEnv
        env_name = 'Heparin-Simulator'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 18728.926, 72.84662, 0.0, 0.0,
            0.0, 0.0, 0.0
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.7251439e+04, 1.0664291e+02,
            200.0, 8.9383472e+02, 1.0025734e+02, 1.5770737e+01, 4.7767456e+01
        ])
        # state_space_low = np.asarray([0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18728.926,72.84662,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
        # state_space_high = np.asarray([1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.7251439e+04,1.0664291e+02,0.0000000e+00,8.9383472e+02,1.4476662e+02,1.3368750e+02,1.6815166e+02,1.0025734e+02,1.5770737e+01,4.7767456e+01,7.7194958e+00])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hep
        observer_reward_shaping = observer_reward_shaping_hep
        max_decor_steps = 3
        b_size = 20
        num_envs = 8
        max_episode_length = 20
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[128],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)
    elif game == "halfcheetah":
        assert not serial
        assert not one_agent
        work_env = gym.make
        env_name = 'HalfCheetah-v2'
        cont_act = True
        temp_env = work_env(env_name)
        state_space_low = np.concatenate([
            np.zeros(temp_env.observation_space.low.shape),
            temp_env.observation_space.low
        ])
        state_space_high = np.concatenate([
            np.ones(temp_env.observation_space.high.shape),
            temp_env.observation_space.high
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = temp_env.action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = None
        observer_reward_shaping = None
        temp_env.close()
        max_decor_steps = 0
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_model_kwargs = dict(hidden_sizes=[256, 256])
        player_q_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_q_model_kwargs = dict(hidden_sizes=[256, 256])
        player_v_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_v_model_kwargs = dict(hidden_sizes=[256, 256])

    # Observer action space: continuous per-feature values for halfcheetah;
    # otherwise a discrete space of size 2**n_features (multiplied by the
    # number of player actions when running as a single agent).
    if game == "halfcheetah":
        observer_act_space = Box(
            low=state_space_low[:int(len(state_space_low) / 2)],
            high=state_space_high[:int(len(state_space_high) / 2)])
    else:
        if serial:
            n_serial = int(len(state_space_high) / 2)
            observer_act_space = Discrete(2)
            observer_act_space.shape = (1, )
        else:
            if one_agent:
                observer_act_space = IntBox(
                    low=0,
                    high=player_act_space.n *
                    int(2**int(len(state_space_high) / 2)))
            else:
                observer_act_space = IntBox(
                    low=0, high=int(2**int(len(state_space_high) / 2)))

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        if eval:
            eval_collector_cl = SerialEvalCollector
        else:
            eval_collector_cl = None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        if eval:
            eval_collector_cl = CpuEvalCollector
        else:
            eval_collector_cl = None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    env_kwargs = dict(work_env=work_env,
                      env_name=env_name,
                      obs_spaces=[obs_space, observer_obs_space],
                      action_spaces=[player_act_space, observer_act_space],
                      serial=serial,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      fully_obs=fully_obs,
                      rand_obs=rand_obs,
                      inc_player_last_act=inc_player_last_act,
                      max_episode_length=max_episode_length,
                      cont_act=cont_act)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0

    sampler = Sampler(
        EnvCls=CWTO_EnvWrapper,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )

    if game == "halfcheetah":
        player_algo = SAC()
        observer_algo = SACBeta()
        player = SacAgent(ModelCls=PiMlpModel,
                          QModelCls=QofMuMlpModel,
                          model_kwargs=player_model_kwargs,
                          q_model_kwargs=player_q_model_kwargs,
                          v_model_kwargs=player_v_model_kwargs)
        observer = SacAgentBeta(ModelCls=PiMlpModelBeta,
                                QModelCls=QofMuMlpModel,
                                model_kwargs=observer_model_kwargs,
                                q_model_kwargs=observer_q_model_kwargs,
                                v_model_kwargs=observer_v_model_kwargs)
    else:
        player_model = CWTO_LstmModel
        observer_model = CWTO_LstmModel
        player_algo = PPO()
        observer_algo = PPO()
        player = CWTO_LstmAgent(ModelCls=player_model,
                                model_kwargs=player_model_kwargs,
                                initial_model_state_dict=None)
        observer = CWTO_LstmAgent(ModelCls=observer_model,
                                  model_kwargs=observer_model_kwargs,
                                  initial_model_state_dict=None)

    if one_agent:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask,
                                  one_agent=one_agent,
                                  nplayeract=player_act_space.n)
    else:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask)

    if eval:
        RunnerCl = MinibatchRlEval
    else:
        RunnerCl = MinibatchRl

    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    if game == "halfcheetah":
        name = "sac_" + game
    else:
        name = "ppo_" + game
    log_dir = os.getcwd() + "/cwto_logs/" + name
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()

    if eval_perf:
        # Run a final serial evaluation and dump the trajectories to disk.
        eval_n_envs = 10
        eval_envs = [CWTO_EnvWrapper(**env_kwargs) for _ in range(eval_n_envs)]
        set_envs_seeds(eval_envs, make_seed())
        eval_collector = SerialEvalCollector(envs=eval_envs,
                                             agent=agent,
                                             TrajInfoCls=TrajInfo_obs,
                                             max_T=1000,
                                             max_trajectories=10,
                                             log_full_obs=True)
        traj_infos_player, traj_infos_observer = eval_collector.collect_evaluation(
            runner.get_n_itr())
        observations = []
        player_actions = []
        returns = []
        observer_actions = []
        for traj in traj_infos_player:
            observations.append(np.array(traj.Observations))
            player_actions.append(np.array(traj.Actions))
            returns.append(traj.Return)
        for traj in traj_infos_observer:
            observer_actions.append(
                np.array([
                    obs_action_translator(act, eval_envs[0].power_vec,
                                          eval_envs[0].obs_size)
                    for act in traj.Actions
                ]))
        # Save results:
        with open('eval_observations.pkl', "wb") as open_obs:
            pickle.dump(observations, open_obs)
        with open('eval_returns.pkl', "wb") as open_ret:
            pickle.dump(returns, open_ret)
        with open('eval_player_actions.pkl', "wb") as open_pact:
            pickle.dump(player_actions, open_pact)
        with open('eval_observer_actions.pkl', "wb") as open_oact:
            pickle.dump(observer_actions, open_oact)
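
# --- Illustrative sketch (not part of the library) ---------------------------
# The non-serial observer above acts in a discrete space of size
# 2**n_features, i.e. one integer per subset of observable state features.
# The repo's ``obs_action_translator`` (with ``power_vec``/``obs_size``)
# performs the real decoding; this standalone bitmask version only shows the
# idea and may differ in bit ordering.
import numpy as np


def decode_observer_action(action, n_features):
    """Return a 0/1 mask over ``n_features`` features for integer ``action``."""
    return np.array([(action >> i) & 1 for i in range(n_features)],
                    dtype=np.float32)


# Example: with 4 observable features, action 5 (binary 0101) reveals
# features 0 and 2: decode_observer_action(5, 4) -> [1., 0., 1., 0.]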