Example #1
import os

# `logger`, `config`, `create_env`, and `GCG` come from the surrounding project.
def run_gcg(params):
    # copy yaml for posterity
    try:
        yaml_path = os.path.join(logger.get_snapshot_dir(),
                                 '{0}.yaml'.format(params['exp_name']))
        with open(yaml_path, 'w') as f:
            f.write(params['txt'])
    except Exception:
        # copying the config is best-effort; don't abort training if it fails
        pass

    os.environ["CUDA_VISIBLE_DEVICES"] = str(
        params['policy']['gpu_device'])  # TODO: hack so we don't claim a second GPU
    config.USE_TF = True

    normalize_env = params['alg'].pop('normalize_env')

    env_str = params['alg'].pop('env')
    env = create_env(env_str, is_normalize=normalize_env, seed=params['seed'])

    env_eval_str = params['alg'].pop('env_eval', env_str)
    env_eval = create_env(env_eval_str,
                          is_normalize=normalize_env,
                          seed=params['seed'])

    env.reset()
    env_eval.reset()

    #####################
    ### Create policy ###
    #####################

    policy_class = params['policy']['class']
    PolicyClass = eval(policy_class)  # resolve the class named in the config; it must be in scope
    policy_params = params['policy'][policy_class]

    policy = PolicyClass(
        env_spec=env.spec,
        exploration_strategies=params['alg'].pop('exploration_strategies'),
        **policy_params,
        **params['policy'])

    ########################
    ### Create algorithm ###
    ########################

    if 'max_path_length' in params['alg']:
        max_path_length = params['alg'].pop('max_path_length')
    else:
        max_path_length = env.horizon
    algo = GCG(env=env,
               env_eval=env_eval,
               policy=policy,
               max_path_length=max_path_length,
               env_str=env_str,
               **params['alg'])
    algo.train()
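For reference, `run_gcg` expects a nested `params` dict whose key layout can be read off the lookups above. Here is a minimal sketch of that shape: the key names mirror the code, but every concrete value, including the policy class name, is an illustrative placeholder.

# Hypothetical config sketch for run_gcg; all values are placeholders.
params = {
    'exp_name': 'my_experiment',       # names the copied yaml file
    'txt': '',                         # raw yaml text saved for posterity
    'seed': 0,
    'policy': {
        'class': 'MyPolicy',           # resolved via eval(), so it must be in scope
        'gpu_device': 0,
        'MyPolicy': {},                # kwargs for the chosen policy class
    },
    'alg': {
        'normalize_env': False,
        'env': 'CartPole-v0',          # passed to create_env
        # 'env_eval' is optional and defaults to the value of 'env'
        'exploration_strategies': {},
        'max_path_length': 200,        # optional; defaults to env.horizon
    },
}
run_gcg(params)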
Example #2
    def __init__(self, folder, num_rollouts):
        """
        :param kwargs: holds random extra properties
        """
        self._folder = folder
        self._num_rollouts = num_rollouts

        ### load data
        self.name = os.path.basename(self._folder)
        with open(self._params_file, 'r') as f:  # _params_file is defined elsewhere in the class
            self.params = yaml.safe_load(f)  # yaml.load without an explicit Loader is deprecated

        self.env = create_env(self.params['alg']['env'])
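A hypothetical usage sketch, assuming this `__init__` belongs to an experiment-loading class (`ExperimentAnalyzer` is a stand-in name) and that the folder contains the params yaml the constructor reads:

# Hypothetical usage; ExperimentAnalyzer is a placeholder for the owning class.
analyzer = ExperimentAnalyzer('/path/to/exp_folder', num_rollouts=25)
print(analyzer.name)                    # basename of the folder
print(analyzer.params['alg']['env'])    # env string loaded from the params yaml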
Example #3
    def __init__(self,
                 policy,
                 env,
                 n_envs,
                 replay_pool_size,
                 max_path_length,
                 sampling_method,
                 save_rollouts=False,
                 save_rollouts_observations=True,
                 save_env_infos=False,
                 env_str=None,
                 replay_pool_params=None):
        # avoid a shared mutable default argument
        replay_pool_params = replay_pool_params if replay_pool_params is not None else {}
        self._policy = policy
        self._n_envs = n_envs

        assert (self._n_envs == 1)  # policy reset currently assumes a single env

        self._replay_pools = [
            RNNCriticReplayPool(
                env.spec,
                env.horizon,
                policy.N,
                policy.gamma,
                replay_pool_size // n_envs,
                obs_history_len=policy.obs_history_len,
                sampling_method=sampling_method,
                save_rollouts=save_rollouts,
                save_rollouts_observations=save_rollouts_observations,
                save_env_infos=save_env_infos,
                replay_pool_params=replay_pool_params) for _ in range(n_envs)
        ]

        try:
            # deep-copy the env by round-tripping it through pickle
            envs = [
                pickle.loads(pickle.dumps(env)) for _ in range(self._n_envs)
            ] if self._n_envs > 1 else [env]
        except Exception:
            # fall back to constructing fresh envs when env is not picklable
            envs = [create_env(env_str) for _ in range(self._n_envs)
                    ] if self._n_envs > 1 else [env]
        ### need to seed each environment if it is a GymEnv
        seed = get_seed()
        if seed is not None and isinstance(utils.inner_env(env), GymEnv):
            for i, env_i in enumerate(envs):  # avoid shadowing the `env` argument
                utils.inner_env(env_i).env.seed(seed + i)
        self._vec_env = VecEnvExecutor(envs=envs,
                                       max_path_length=max_path_length)
        self._curr_observations = self._vec_env.reset()
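Finally, a hypothetical construction of this sampler (`Sampler` stands in for the owning class; the sampling method string and env name are placeholders passed through to `RNNCriticReplayPool` and `create_env`):

# Hypothetical usage; Sampler is a placeholder for the class owning this __init__.
sampler = Sampler(policy=policy,
                  env=env,
                  n_envs=1,                   # the assert above requires one env
                  replay_pool_size=int(1e5),
                  max_path_length=env.horizon,
                  sampling_method='uniform',  # placeholder value
                  env_str='CartPole-v0')      # only used to rebuild envs when n_envs > 1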