import gym

from stable_baselines import PPO2, logger
from stable_baselines.bench import Monitor
from stable_baselines.common import set_global_seeds
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize


def train(env_id, num_timesteps, seed):
    """
    Train a PPO2 model on a Mujoco environment, for testing purposes

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) used to seed the random generator
    """
    def make_env():
        # wrap the raw gym env with a Monitor so episode stats are logged
        env_out = gym.make(env_id)
        env_out = Monitor(env_out, logger.get_dir(), allow_early_resets=True)
        return env_out

    env = DummyVecEnv([make_env])
    env = VecNormalize(env)  # normalize observations and rewards online

    set_global_seeds(seed)
    policy = MlpPolicy
    model = PPO2(policy=policy, env=env, n_steps=2048, nminibatches=32,
                 lam=0.95, gamma=0.99, noptepochs=10, ent_coef=0.0,
                 learning_rate=3e-4, cliprange=0.2)
    model.learn(total_timesteps=num_timesteps)

    return model, env
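# Hedged usage sketch (not in the original file; 'Hopper-v2' is an
# illustrative Mujoco env id): a quick smoke test of train() that rolls
# the trained policy for a single deterministic step.
if __name__ == '__main__':
    model, env = train('Hopper-v2', num_timesteps=2048, seed=0)
    obs = env.reset()
    action, _states = model.predict(obs, deterministic=True)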
def __init__(self, inp, callback):
    """
    Input:
        inp: a dictionary of validated user input
        callback: a callback class built on stable-baselines that allows
                  intervening during training to process data and save models
    """
    # self.inp.gen_dict  ---> GENERAL CARD: env, env_data, nactions, ...etc
    # self.inp.acer_dict ---> ACER CARD
    self.inp = inp  # the full user input dictionary
    self.callback = callback
    self.mode = self.inp.acer_dict['mode'][0]
    self.log_dir = self.inp.gen_dict['log_dir']
    set_global_seeds(3)
def __init__(self, inp, callback):
    """
    Input:
        inp: a dictionary of validated user input, e.g. {"ncores": 8, "env": "6x6", ...}
        callback: a callback class built on stable-baselines that allows
                  intervening during training to process data and save models
    """
    # self.inp.gen_dict ---> GENERAL CARD: env, env_data, nactions, ...etc
    # self.inp.ppo_dict ---> PPO CARD
    self.inp = inp  # the full user input dictionary
    self.callback = callback
    self.mode = self.inp.ppo_dict['mode'][0]
    self.log_dir = self.inp.gen_dict['log_dir']
    set_global_seeds(3)
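# Hedged illustration (assumed shape, not from the source): each "card" is
# exposed as a dictionary whose values are mostly stored as lists, which is
# why most fields are read with a trailing [0]:
#
#   inp.gen_dict = {'env': [...], 'env_data': [...],
#                   'exepath': ['/path/to/exe'], 'log_dir': './log'}
#   inp.ppo_dict = {'mode': ['train'], 'casename': ['ppo_case'], ...}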
def make_env(self, env_id, rank, seed=0):
    """
    Make multiprocessed/parallel envs based on gym.make with specific seeds

    env_id: (str) the environment ID
    rank: (int) index of the subprocess
    seed: (int) the initial seed for the RNG

    Returns: _init, a callable that creates a gym environment with a
             subprocess-specific seed (seed + rank)
    """
    def _init():
        env = gym.make(env_id,
                       casename=self.inp.ppo_dict['casename'][0],
                       exepath=self.inp.gen_dict['exepath'][0],
                       log_dir=self.log_dir,
                       env_data=self.inp.gen_dict['env_data'][0],
                       env_seed=seed + rank)
        env.seed(seed + rank)
        return env

    set_global_seeds(seed)
    return _init
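# Hedged usage sketch (assumed, not from the original source). The _init
# thunks returned by make_env are typically collected into a SubprocVecEnv
# so that each worker process builds its own environment with a distinct
# seed (seed + rank). 'trainer', 'env_id', and 'ncores' are illustrative.
from stable_baselines.common.vec_env import SubprocVecEnv

def build_parallel_envs(trainer, env_id, ncores):
    # trainer is an instance of the class that defines make_env above
    return SubprocVecEnv([trainer.make_env(env_id, rank=i) for i in range(ncores)])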
import argparse
from functools import partial

from stable_baselines import logger
from stable_baselines.bench import Monitor
from stable_baselines.common import set_global_seeds
from stable_baselines.common.atari_wrappers import make_atari
from stable_baselines.deepq import DQN, CnnPolicy, wrap_atari_dqn


def main():
    """
    Run the Atari test
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    args = parser.parse_args()

    logger.configure()
    set_global_seeds(args.seed)

    # build the Atari env with standard DQN preprocessing and logging
    env = make_atari(args.env)
    env = Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)

    policy = partial(CnnPolicy, dueling=args.dueling == 1)
    model = DQN(
        env=env,
        policy=policy,
        learning_rate=1e-4,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
    )
    model.learn(total_timesteps=args.num_timesteps)
    env.close()
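# Hedged addition (assumed entry point; the original snippet ends at
# env.close()): the standard guard so the script can be run from the CLI.
if __name__ == '__main__':
    main()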