def run_gcg(params):
    # copy yaml for posterity
    try:
        yaml_path = os.path.join(logger.get_snapshot_dir(), '{0}.yaml'.format(params['exp_name']))
        with open(yaml_path, 'w') as f:
            f.write(params['txt'])
    except:
        pass

    os.environ["CUDA_VISIBLE_DEVICES"] = str(params['policy']['gpu_device'])  # TODO: hack so don't double GPU
    config.USE_TF = True

    normalize_env = params['alg'].pop('normalize_env')

    env_str = params['alg'].pop('env')
    env = create_env(env_str, is_normalize=normalize_env, seed=params['seed'])

    env_eval_str = params['alg'].pop('env_eval', env_str)
    env_eval = create_env(env_eval_str, is_normalize=normalize_env, seed=params['seed'])

    env.reset()
    env_eval.reset()

    #####################
    ### Create policy ###
    #####################

    policy_class = params['policy']['class']
    PolicyClass = eval(policy_class)
    policy_params = params['policy'][policy_class]

    policy = PolicyClass(
        env_spec=env.spec,
        exploration_strategies=params['alg'].pop('exploration_strategies'),
        **policy_params,
        **params['policy'])

    ########################
    ### Create algorithm ###
    ########################

    if 'max_path_length' in params['alg']:
        max_path_length = params['alg'].pop('max_path_length')
    else:
        max_path_length = env.horizon

    algo = GCG(
        env=env,
        env_eval=env_eval,
        policy=policy,
        max_path_length=max_path_length,
        env_str=env_str,
        **params['alg'])
    algo.train()
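
# ---------------------------------------------------------------------------
# Usage sketch (not from the original source): run_gcg expects a fully
# populated `params` dict. The helper below is a hypothetical launcher; the
# `launch_from_yaml` name and the yaml path are assumptions, and only the keys
# that run_gcg actually reads above ('exp_name', 'txt', 'seed', 'policy',
# 'alg') are taken from the function itself.
import yaml

def launch_from_yaml(yaml_path):
    with open(yaml_path, 'r') as f:
        txt = f.read()
    params = yaml.safe_load(txt)
    params['txt'] = txt  # raw yaml text; run_gcg writes this copy back out for posterity
    run_gcg(params)
# ---------------------------------------------------------------------------
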
def __init__(self, folder, num_rollouts):
    """
    :param folder: experiment directory containing the params yaml and rollout data
    :param num_rollouts: number of rollouts to load for analysis
    """
    self._folder = folder
    self._num_rollouts = num_rollouts

    ### load data
    self.name = os.path.basename(self._folder)
    # self._params_file is assumed to be defined elsewhere on the class
    # (a path to the experiment's params yaml inside self._folder)
    with open(self._params_file, 'r') as f:
        self.params = yaml.load(f)
    self.env = create_env(self.params['alg']['env'])
def __init__(self, policy, env, n_envs, replay_pool_size, max_path_length, sampling_method,
             save_rollouts=False, save_rollouts_observations=True, save_env_infos=False,
             env_str=None, replay_pool_params={}):
    self._policy = policy
    self._n_envs = n_envs
    assert (self._n_envs == 1)  # b/c policy reset

    self._replay_pools = [
        RNNCriticReplayPool(
            env.spec,
            env.horizon,
            policy.N,
            policy.gamma,
            replay_pool_size // n_envs,
            obs_history_len=policy.obs_history_len,
            sampling_method=sampling_method,
            save_rollouts=save_rollouts,
            save_rollouts_observations=save_rollouts_observations,
            save_env_infos=save_env_infos,
            replay_pool_params=replay_pool_params)
        for _ in range(n_envs)
    ]

    try:
        envs = [pickle.loads(pickle.dumps(env)) for _ in range(self._n_envs)] if self._n_envs > 1 else [env]
    except:
        envs = [create_env(env_str) for _ in range(self._n_envs)] if self._n_envs > 1 else [env]

    ### need to seed each environment if it is GymEnv
    seed = get_seed()
    if seed is not None and isinstance(utils.inner_env(env), GymEnv):
        for i, env in enumerate(envs):
            utils.inner_env(env).env.seed(seed + i)

    self._vec_env = VecEnvExecutor(
        envs=envs,
        max_path_length=max_path_length)

    self._curr_observations = self._vec_env.reset()