Example #1
def train(
    _run,
    root_dir,
    exp_name,
    num_env,
    rl_algo,
    learning_rate,
    log_output_formats,
    embed_type,
    embed_path,
    embed_types,
    embed_paths,
    adv_noise_params,
):
    embed_types, embed_paths, adv_noise_params = resolve_embed(
        embed_type, embed_path, embed_types, embed_paths, adv_noise_params
    )

    scheduler = Scheduler(annealer_dict={"lr": ConstantAnnealer(learning_rate)})
    out_dir, logger = setup_logger(root_dir, exp_name, output_formats=log_output_formats)
    log_callbacks, save_callbacks = [], []

    if rl_algo in NO_VECENV and num_env > 1:
        raise ValueError(f"'{rl_algo}' needs 'num_env' set to 1.")

    multi_venv, our_idx = build_env(out_dir, embed_types=embed_types)
    multi_venv = multi_wrappers(multi_venv, log_callbacks=log_callbacks)
    multi_venv = maybe_embed_agent(
        multi_venv,
        our_idx,
        scheduler,
        log_callbacks=log_callbacks,
        embed_types=embed_types,
        embed_paths=embed_paths,
        adv_noise_params=adv_noise_params,
    )
    single_venv = FlattenSingletonVecEnv(multi_venv)
    single_venv = single_wrappers(
        single_venv,
        scheduler,
        our_idx,
        log_callbacks=log_callbacks,
        save_callbacks=save_callbacks,
        embed_paths=embed_paths,
        embed_types=embed_types,
    )

    train_fn = RL_ALGOS[rl_algo]
    res = train_fn(
        env=single_venv,
        out_dir=out_dir,
        learning_rate=scheduler.get_annealer("lr"),
        logger=logger,
        log_callbacks=log_callbacks,
        save_callbacks=save_callbacks,
    )
    single_venv.close()

    return res
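Both train variants dispatch on a registry of training functions and guard against algorithms that cannot run with vectorized environments. Below is a minimal sketch of that pattern; RL_ALGOS and NO_VECENV mirror the names used above, but the trainer bodies and algorithm keys here are placeholders, not the project's actual implementations.

# Illustrative sketch of the dispatch pattern above (placeholder trainers, not the real project code).

def _train_ppo(env, out_dir, learning_rate, logger, log_callbacks, save_callbacks):
    # A real trainer would run PPO on `env` and return the trained model / results.
    return {"algo": "ppo", "out_dir": out_dir}

def _train_gail(env, out_dir, learning_rate, logger, log_callbacks, save_callbacks):
    # Stand-in for an algorithm that only supports a single (non-vectorized) environment.
    return {"algo": "gail", "out_dir": out_dir}

RL_ALGOS = {"ppo": _train_ppo, "gail": _train_gail}
NO_VECENV = {"gail"}  # algorithms that require num_env == 1

def dispatch(rl_algo, num_env, **train_kwargs):
    if rl_algo in NO_VECENV and num_env > 1:
        raise ValueError(f"'{rl_algo}' needs 'num_env' set to 1.")
    return RL_ALGOS[rl_algo](**train_kwargs)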
Example #2
def train(_run, root_dir, exp_name, num_env, rl_algo, learning_rate, log_output_formats):
    scheduler = Scheduler(annealer_dict={'lr': ConstantAnnealer(learning_rate)})
    out_dir, logger = setup_logger(root_dir, exp_name, output_formats=log_output_formats)
    log_callbacks, save_callbacks = [], []
    pylog.info(f"Log output formats: {logger.output_formats}")

    if rl_algo in NO_VECENV and num_env > 1:
        raise ValueError(f"'{rl_algo}' needs 'num_env' set to 1.")

    multi_venv, our_idx = build_env(out_dir)
    multi_venv = multi_wrappers(multi_venv, log_callbacks=log_callbacks)
    multi_venv = maybe_embed_victim(multi_venv, our_idx, scheduler, log_callbacks=log_callbacks)

    single_venv = FlattenSingletonVecEnv(multi_venv)
    single_venv = single_wrappers(single_venv, scheduler, our_idx,
                                  log_callbacks=log_callbacks, save_callbacks=save_callbacks)

    train_fn = RL_ALGOS[rl_algo]
    res = train_fn(env=single_venv, out_dir=out_dir, learning_rate=scheduler.get_annealer('lr'),
                   logger=logger, log_callbacks=log_callbacks, save_callbacks=save_callbacks)
    single_venv.close()

    return res
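In both examples the learning rate reaches the trainer as an annealer fetched from a Scheduler. A rough reimplementation of that pair, shown only to make clear what get_annealer("lr") returns, might look like the following; this is an illustrative sketch, not the project's actual Scheduler or ConstantAnnealer classes.

# Illustrative sketch only; the real Scheduler/ConstantAnnealer live in the training package.

class ConstantAnnealer:
    """Schedule whose value stays fixed for the whole run."""
    def __init__(self, value):
        self.value = value

    def __call__(self, _frac_remaining=1.0):
        return self.value

class Scheduler:
    """Holds named annealers so callers can fetch e.g. the learning-rate schedule by key."""
    def __init__(self, annealer_dict):
        self._annealers = dict(annealer_dict)

    def get_annealer(self, key):
        return self._annealers[key]

scheduler = Scheduler(annealer_dict={"lr": ConstantAnnealer(3e-4)})
lr_schedule = scheduler.get_annealer("lr")
assert lr_schedule() == 3e-4  # the trainer can call this at every update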
Example #3
def create_multi_agent_curried_policy_wrapper(mon_dir,
                                              env_name,
                                              num_envs,
                                              embed_index,
                                              max_steps,
                                              state_shape=None,
                                              add_zoo=False,
                                              num_zoo=5):
    def episode_limit(env):
        return time_limit.TimeLimit(env, max_episode_steps=max_steps)

    def env_fn(i):
        return make_env(env_name,
                        seed=42,
                        i=i,
                        out_dir=mon_dir,
                        pre_wrappers=[episode_limit])

    # Bind i at definition time; a bare `lambda: env_fn(i)` would make every
    # environment use the final value of i because of Python's late binding.
    vec_env = make_dummy_vec_multi_env(
        [lambda i=i: env_fn(i) for i in range(num_envs)])

    zoo = load_policy(
        policy_path="1",
        policy_type="zoo",
        env=vec_env,
        env_name=env_name,
        index=1 - embed_index,
        transparent_params=None,
    )
    half_env = FakeSingleSpacesVec(vec_env, agent_id=embed_index)
    policies = [
        _get_constant_policy(half_env,
                             constant_value=half_env.action_space.sample(),
                             state_shape=state_shape) for _ in range(10)
    ]
    if add_zoo:
        policies += [zoo] * num_zoo

    policy_wrapper = MultiPolicyWrapper(policies=policies, num_envs=num_envs)

    vec_env = CurryVecEnv(venv=vec_env,
                          policy=policy_wrapper,
                          agent_idx=embed_index,
                          deterministic=False)
    vec_env = FlattenSingletonVecEnv(vec_env)

    yield vec_env, policy_wrapper, zoo
    policy_wrapper.close()
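create_multi_agent_curried_policy_wrapper yields exactly once and then closes the policy wrapper, so it is natural to drive it as a context manager (for example via contextlib) or as a pytest fixture. The snippet below is a hypothetical usage sketch; the monitor directory, environment name, and argument values are assumptions.

# Hypothetical usage of the generator above; paths and env name are placeholders.
import contextlib

curried_env = contextlib.contextmanager(create_multi_agent_curried_policy_wrapper)

with curried_env(mon_dir="/tmp/mon", env_name="multicomp/SumoAnts-v0",
                 num_envs=2, embed_index=0, max_steps=100) as (venv, wrapper, zoo):
    obs = venv.reset()
    # ... interact with the flattened single-agent env here ...
# on exit, the generator resumes and policy_wrapper.close() runs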
Example #4
    def _create_lb_tuples(self, env_name, use_debug, victim_index, victim_path,
                          victim_type):
        """Create lookback data structures which are used to compare our episode rollouts against
        those of an environment where a lookback base policy acted instead.

        params victim_index, victim_path, victim_type are the same as in policy_loader.load_policy
        :param use_debug (bool): Use DummyVecEnv instead of SubprocVecEnv
        :return: (list<LookbackTuple>) lb_tuples
        """
        def env_fn(i):
            return make_env(
                env_name,
                0,
                i,
                out_dir='data/lookbacks/',
                pre_wrappers=[GymCompeteToOurs, OldMujocoResettableWrapper])

        lb_tuples = []
        for _ in range(self.lb_num):
            make_vec_env = make_dummy_vec_multi_env if use_debug else make_subproc_vec_multi_env
            # Bind i at definition time to avoid the late-binding closure bug.
            multi_venv = make_vec_env(
                [lambda i=i: env_fn(i) for i in range(self.num_envs)])
            if use_debug:
                multi_venv = DebugVenv(multi_venv)

            victim = load_policy(policy_path=victim_path,
                                 policy_type=victim_type,
                                 env=multi_venv,
                                 env_name=env_name,
                                 index=victim_index,
                                 transparent_params=self.transparent_params)

            multi_venv = EmbedVictimWrapper(multi_env=multi_venv,
                                            victim=victim,
                                            victim_index=victim_index,
                                            transparent=True,
                                            deterministic=True)

            single_venv = FlattenSingletonVecEnv(multi_venv)
            data_dict = {
                'state': None,
                'action': None,
                'info': defaultdict(dict)
            }
            lb_tuples.append(LookbackTuple(venv=single_venv, data=data_dict))
        return lb_tuples
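_create_lb_tuples pairs each lookback environment with a fresh data dict. A plausible definition of LookbackTuple, shown here only to illustrate the data layout, is a namedtuple over the venv and that dict; the defaultdict(dict) lets per-environment info accumulate lazily.

# Hypothetical definition matching how LookbackTuple is used above.
from collections import defaultdict, namedtuple

LookbackTuple = namedtuple("LookbackTuple", ["venv", "data"])

data = {"state": None, "action": None, "info": defaultdict(dict)}
lb = LookbackTuple(venv=None, data=data)  # venv would be a FlattenSingletonVecEnv
lb.data["info"][3]["episode"] = {}        # per-env keys appear on first access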