Example 1
def scoped_init_procgen():
    from procgen import ProcgenEnv
    env = ProcgenEnv(num_envs=2,
                     env_name="coinrun",
                     num_levels=12,
                     start_level=34)
    after_init = file_descriptor_count()

    env.close()
    return after_init
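
The snippet above calls a file_descriptor_count() helper that is not shown. A minimal, Linux-only sketch of such a helper (hypothetical; it assumes the test only needs the number of descriptors this process currently has open) could be:

import os

def file_descriptor_count():
    # Count this process's open file descriptors by listing /proc/self/fd (Linux only).
    return len(os.listdir("/proc/self/fd"))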
Example 2
def create_collector_env_gym(seed):

    options = {
        "world_dim": int(16),
        "init_locator_type": int(3),
        "num_goals_green": int(1),
        "num_goals_red": int(1),
        "num_resources_green": int(0),
        "num_resources_red": int(0),
        "num_fuel": int(0),
        "num_obstacles": int(2),
        "goal_max": 20.0,
        "goal_init": 0.0,
        "agent_max_fuel": 100.0,
        "agent_init_fuel": 100.0,
        "agent_max_resources": 100.0,
        "agent_init_resources_green": 20.0,
        "agent_init_resources_red": 10.0,
    }

    kwargs = {
        "start_level": seed if seed is not None else 102,
        "num_levels": 10,
        "additional_obs_spaces": [
            ProcgenEnv.C_Space("state_ship", False, (9,), float, (-1e6, 1e6)),
            ProcgenEnv.C_Space(
                "state_goals", False,
                ((options["num_goals_green"] + options["num_goals_red"]) * 4,),
                float, (-1e6, 1e6)),
            ProcgenEnv.C_Space(
                "state_resources", False,
                ((options["num_resources_green"] + options["num_resources_red"]) * 4,),
                float, (-1e6, 1e6)),
            ProcgenEnv.C_Space(
                "state_obstacles", False,
                (options["num_obstacles"] * 3,),
                float, (-1e6, 1e6)),
        ],
        "max_episodes_per_game": 0,
        "options": options,
    }

    # env = gym.make('procgen:procgen-collector-v0',**kwargs)

    env = ProcgenEnv(num_envs=9, env_name="collector", **kwargs)
    dtv = DictToVec([
        space.name for space in kwargs["additional_obs_spaces"]
        if np.prod(space.shape) > 0
    ])
    env = StateTransformerEnv(env, dtv)
    return env
Example 3
    def __call__(self, params, agent, env):

        torch.manual_seed(1)
        np.random.seed(1)

        env = ProcgenEnv(env_name="coinrun", render_mode="rgb_array")
        step = 0
        for i in range(100):

            env.act(gym3.types_np.sample(env.ac_space, bshape=(env.num, )))
            rew, obs, first = env.observe()
            print(f"step {step} reward {rew} first {first}")
            step += 1
Example 4
def test_one_env(alt_flag,
                 model,
                 start_level,
                 num_levels,
                 logger,
                 args,
                 env=None):
    ## Adapted from random_ppo.learn
    if not env:
        venv = ProcgenEnv(num_envs=num_envs,
                          env_name=args.env_name,
                          num_levels=num_levels,
                          start_level=start_level,
                          distribution_mode=args.distribution_mode)
        venv = VecExtractDictObs(venv, "rgb")
        venv = VecMonitor(
            venv=venv,
            filename=None,
            keep_buf=100,
        )
        venv = VecNormalize(venv=venv, ob=False)
        env = venv

    runner = TestRunner(env=env,
                        model=model,
                        nsteps=nsteps,
                        gamma=gamma,
                        lam=lam)

    epinfobuf10 = deque(maxlen=10)
    epinfobuf100 = deque(maxlen=100)
    mean_rewards = []
    datapoints = []
    for rollout in range(1, args.nrollouts + 1):
        logger.info('collecting rollouts {}...'.format(rollout))
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run(
            alt_flag)
        epinfobuf10.extend(epinfos)
        epinfobuf100.extend(epinfos)

        rew_mean_10 = safemean([epinfo['r'] for epinfo in epinfobuf10])
        rew_mean_100 = safemean([epinfo['r'] for epinfo in epinfobuf100])
        ep_len_mean_10 = np.nanmean([epinfo['l'] for epinfo in epinfobuf10])
        ep_len_mean_100 = np.nanmean([epinfo['l'] for epinfo in epinfobuf100])

        logger.info('\n----', rollout)
        mean_rewards.append(rew_mean_10)
        logger.logkv('start_level', start_level)
        logger.logkv('eprew10', rew_mean_10)
        logger.logkv('eprew100', rew_mean_100)
        logger.logkv('eplenmean10', ep_len_mean_10)
        logger.logkv('eplenmean100', ep_len_mean_100)
        logger.logkv("misc/total_timesteps", rollout * args.nbatch)

        logger.info('----\n')
        logger.dumpkvs()
    env.close()
    logger.info("Average reward on levels {} ~ {}: {} ".format(
        start_level, start_level + num_levels, mean_rewards))
    return np.mean(mean_rewards)
Example 5
 def __init__(self, vision, sync=False, **kwargs):
     self._vision = vision
     venv = ProcgenEnv(num_envs=1, **kwargs)
     self.combos = list(venv.unwrapped.combos)
     self.last_keys = []
     env = Scalarize(venv)
     super().__init__(env=env, sync=sync, tps=15, display_info=True)
Example 6
 def collect_observations():
     rng = np.random.RandomState(0)
     venv = ProcgenEnv(num_envs=2, env_name=env_name, rand_seed=23)
     obs = venv.reset()
     obses = [obs["rgb"]]
     for _ in range(128):
         obs, _rew, _done, _info = venv.step(
             rng.randint(
                 low=0,
                 high=venv.action_space.n,
                 size=(venv.num_envs,),
                 dtype=np.int32,
             )
         )
         obses.append(obs["rgb"])
     return np.array(obses)
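
With rand_seed fixed and actions drawn from a seeded RandomState, two runs should produce identical frames. A possible check, assuming env_name is defined in the enclosing scope (e.g. "coinrun"):

obs_a = collect_observations()
obs_b = collect_observations()
assert np.array_equal(obs_a, obs_b)  # the rollout is deterministic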
Example 7
def make_env(
	n_envs=32,
	env_name='coinrun',
	start_level=0,
	num_levels=100,
	use_backgrounds=False, #No Background 
	normalize_obs=False,
	distribution_mode="easy", # Train with easy levels
	normalize_reward=True,
	seed=0,
	seed_levels=0
	):
	"""Make environment for procgen experiments"""
	set_global_seeds(seed)
	set_global_log_levels(40)
	env = ProcgenEnv(
		num_envs=n_envs,
		env_name=env_name,
		start_level=start_level,
		num_levels=num_levels,
		use_generated_assets=True,
		distribution_mode=distribution_mode,
		use_backgrounds=use_backgrounds,
		restrict_themes=not use_backgrounds,
		render_mode='rgb_array',
		rand_seed=seed_levels
	)
	env = VecExtractDictObs(env, "rgb")
	env = VecNormalize(env, ob=normalize_obs, ret=normalize_reward)
	env = TransposeFrame(env)
	env = ScaledFloatFrame(env)
	env = TensorEnv(env)
	
	return env
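
TransposeFrame, ScaledFloatFrame and TensorEnv are project-specific wrappers that are not shown here. Purely as an illustration, a minimal channel-transposing wrapper built on baselines' VecEnvWrapper might look like this (a sketch under the assumption of uint8 HWC frames, not the original implementation):

import numpy as np
from gym import spaces
from baselines.common.vec_env import VecEnvWrapper

class TransposeFrame(VecEnvWrapper):
    # Sketch: convert HWC image observations to CHW.
    def __init__(self, venv):
        h, w, c = venv.observation_space.shape
        obs_space = spaces.Box(low=0, high=255, shape=(c, h, w), dtype=np.uint8)
        super().__init__(venv, observation_space=obs_space)

    def reset(self):
        return self.venv.reset().transpose(0, 3, 1, 2)

    def step_wait(self):
        obs, rew, done, info = self.venv.step_wait()
        return obs.transpose(0, 3, 1, 2), rew, done, info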
Example 8
def make_venv(env_id,
              num_envs=4,
              num_levels=0,
              start_level=0,
              distribution_mode='easy'):

    if env_id == 'cartpole-visual-v1':
        venv = gym.vector.make('cartpole-visual-v1',
                               num_envs=num_envs,
                               num_levels=num_levels,
                               start_level=start_level)
        venv.observation_space = gym.spaces.Box(low=0,
                                                high=255,
                                                shape=(3, 64, 64),
                                                dtype=np.uint8)
        venv.action_space = gym.spaces.Discrete(2)

    else:
        venv = ProcgenEnv(env_name=env_id,
                          num_envs=num_envs,
                          num_levels=num_levels,
                          start_level=start_level,
                          distribution_mode=distribution_mode)
        venv = VecExtractDictObs(venv, "rgb")
        venv = TransposeImage(venv)

    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    return venv
Example 9
def _make(**env_config):
    env = ProcgenEnv(**env_config)
    env = EpisodeRewardWrapper(env)
    env = RemoveDictObs(env, key="rgb")
    env = ReshapeAction(env)
    env = PermuteShapeObservation(env)
    return env
Example 10
 def _make_env(self, start, size):
     return ProcEnvAdapter(ProcgenEnv(
         num_envs=self.parallel_env,
         start_level=start,
         num_levels=size,
         env_name=self.env_name,
         num_threads=self.num_thread,
         rand_seed=self.seed
     ), self.transforms)
Example 11
def test_all(alt_flag, load_path, logger, args):
    train_end = int(args.train_level)
    config = tf.compat.v1.ConfigProto(
        log_device_placement=True)  #device_count={'GPU':0})
    config.gpu_options.allow_growth = True  #pylint: disable=E1101
    sess = tf.compat.v1.Session(config=config)

    venv = ProcgenEnv(num_envs=num_envs,
                      env_name=args.env_name,
                      num_levels=train_end,
                      start_level=0,
                      distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(
        venv=venv,
        filename=None,
        keep_buf=100,
    )
    venv = VecNormalize(venv=venv, ob=False)

    env = venv
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    nrollouts = args.total_tsteps // nbatch
    args.nrollouts = nrollouts
    args.nbatch = nbatch

    model = Model(sess=sess,
                  policy=EnsembleCnnPolicy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nbatch_act=nenvs,
                  nbatch_train=nbatch_train,
                  nsteps=nsteps,
                  ent_coef=ent_coef,
                  vf_coef=0.5,
                  max_grad_norm=0.5)
    model.load(load_path)
    logger.info("Model params loaded from saved model: ", load_path)

    mean_rewards = []
    ## first, test train performance

    mean_rewards.append(
        test_one_env(alt_flag, model, 0, train_end, logger, args, env=env))

    ## then, test on sampled intervals
    for l in TEST_START_LEVELS:
        mean_rewards.append(
            test_one_env(alt_flag, model, l, 100, logger, args, env=None))

    logger.info("All tests finished, mean reward history: ", mean_rewards)
    return
Example 12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision",
                        choices=["agent", "human"],
                        default="human")
    parser.add_argument("--record-dir", help="directory to record movies to")
    parser.add_argument(
        "--distribution-mode",
        default="hard",
        help="which distribution mode to use for the level generation")
    parser.add_argument("--level-seed",
                        type=int,
                        help="select an individual level to use")
    parser.add_argument("--use-generated-assets",
                        help="using autogenerated assets",
                        choices=["yes", "no"],
                        default="no")
    args = parser.parse_args()

    kwargs = {"distribution_mode": args.distribution_mode}
    kwargs["use_generated_assets"] = True if (args.use_generated_assets
                                              == "yes") else False

    if args.level_seed is not None:
        kwargs["start_level"] = args.level_seed
        kwargs["num_levels"] = 1

    world_dim = int(10)
    kwargs["additional_info_spaces"] = [
        ProcgenEnv.C_Space("state", False, (7 + world_dim * world_dim, ),
                           bytes, (0, 255))
    ]

    kwargs["options"] = {
        'world_dim': world_dim,
        'wall_chance': 0.5,
        'fire_chance': 0.3,
        'water_chance': 0.2,
        'num_keys': int(2),
        'num_doors': int(1),
        'with_grid_steps': True,
        'completion_bonus': 10.0,
        'fire_bonus': -5.0,
        'water_bonus': -2.0,
        'action_bonus': -1.0,
    }

    ia = ProcgenInteractive(args.vision, True, env_name="heistpp", **kwargs)

    ia.skip_info_out("state")

    step_cb = HeistppStatePlotter(world_dim, 1)

    ia.add_step_callback(step_cb)

    ia.run(record_dir=args.record_dir)
Example 13
def make_procgen_env(env_name, num_envs, device):
    venv = ProcgenEnv(env_name=env_name,
                      num_envs=num_envs,
                      distribution_mode="easy",
                      num_levels=0,
                      start_level=0)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv)
    envs = VecNormalize(venv=venv, norm_obs=False)
    envs = VecPyTorch(envs, device)
    # VecVideoRecorder(envs, f'videos/{experiment_name}', record_video_trigger=lambda x: x % 1000000== 0, video_length=100)
    return envs
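
VecPyTorch is a project-specific wrapper that is not shown; conceptually it turns the NumPy observations and rewards into torch tensors on the chosen device. A rough sketch under that assumption (hypothetical, built on Stable-Baselines3's VecEnvWrapper):

import torch
from stable_baselines3.common.vec_env import VecEnvWrapper

class VecPyTorch(VecEnvWrapper):
    # Sketch: move observations/rewards to torch tensors on `device`.
    def __init__(self, venv, device):
        super().__init__(venv)
        self.device = device

    def reset(self):
        return torch.from_numpy(self.venv.reset()).float().to(self.device)

    def step_async(self, actions):
        self.venv.step_async(actions.cpu().numpy())

    def step_wait(self):
        obs, rew, done, info = self.venv.step_wait()
        obs = torch.from_numpy(obs).float().to(self.device)
        rew = torch.from_numpy(rew).unsqueeze(dim=1).float()
        return obs, rew, done, info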
Example 14
def make_lr_venv(num_envs, env_name, seeds, device, **kwargs):
    level_sampler = kwargs.get('level_sampler')
    level_sampler_args = kwargs.get('level_sampler_args')

    ret_normalization = not kwargs.get('no_ret_normalization', False)

    if env_name in PROCGEN_ENVS:
        num_levels = kwargs.get('num_levels', 1)
        start_level = kwargs.get('start_level', 0)
        distribution_mode = kwargs.get('distribution_mode', 'easy')
        paint_vel_info = kwargs.get('paint_vel_info', False)

        venv = ProcgenEnv(num_envs=num_envs, env_name=env_name, \
            num_levels=num_levels, start_level=start_level, \
            distribution_mode=distribution_mode,
            paint_vel_info=paint_vel_info)
        venv = VecExtractDictObs(venv, "rgb")
        venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
        venv = VecNormalize(venv=venv, ob=False, ret=ret_normalization)

        if level_sampler_args:
            level_sampler = LevelSampler(
                seeds, 
                venv.observation_space, venv.action_space,
                **level_sampler_args)

        envs = VecPyTorchProcgen(venv, device, level_sampler=level_sampler)

    elif env_name.startswith('MiniGrid'):
        venv = VecMinigrid(num_envs=num_envs, env_name=env_name, seeds=seeds)
        venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
        venv = VecNormalize(venv=venv, ob=False, ret=ret_normalization)

        if level_sampler_args:
            level_sampler = LevelSampler(
                seeds, 
                venv.observation_space, venv.action_space,
                **level_sampler_args)

        elif seeds:
            level_sampler = LevelSampler(
                seeds,
                venv.observation_space, venv.action_space,
                strategy='random',
            )

        envs = VecPyTorchMinigrid(venv, device, level_sampler=level_sampler)

    else:
        raise ValueError(f'Unsupported env {env_name}')

    return envs, level_sampler
Example 15
def make_env(env_name, num_processes, device, num_levels, start_level,
             distribution_mode):
    print('make_env')
    venv = ProcgenEnv(env_name=env_name,
                      num_envs=num_processes,
                      num_levels=num_levels,
                      start_level=start_level,
                      distribution_mode=distribution_mode)
    venv = VecExtractDictObs(venv, 'rgb')
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    venv = VecPyTorchProcgen(venv, device)

    return venv
Example 16
    def __init__(self, model, config_dir: pathlib.Path, n_trajectories: int, tunable_params: List[EnvironmentParameter]):
        self._model = model
        self._n_trajectories = n_trajectories

        # Initialize the environment
        easy_config_path = config_dir / 'test_easy_config.json'
        easy_config = copy.copy(BossfightEasyConfig)
        easy_config.to_json(easy_config_path)
        easy_env = ProcgenEnv(num_envs=1, env_name=str(easy_config.game), domain_config_path=str(easy_config_path))
        easy_env = VecExtractDictObs(easy_env, "rgb")
        easy_env = VecMonitor(venv=easy_env, filename=None, keep_buf=100)
        self.easy_env = VecNormalize(venv=easy_env, ob=False)

        hard_config_path = config_dir / 'test_hard_config.json'
        hard_config = copy.copy(BossfightHardConfig)
        hard_config.to_json(hard_config_path)
        hard_env = ProcgenEnv(num_envs=1, env_name=str(hard_config.game), domain_config_path=str(hard_config_path))
        hard_env = VecExtractDictObs(hard_env, "rgb")
        hard_env = VecMonitor(venv=hard_env, filename=None, keep_buf=100)
        self.hard_env = VecNormalize(venv=hard_env, ob=False)

        # Make a default config for bossfight...
        test_domain_config_path = config_dir / 'test_full_config.json'
        test_domain_config = DEFAULT_DOMAIN_CONFIGS['dc_bossfight']
        test_domain_config.to_json(test_domain_config_path)

        params = {}
        for param in tunable_params:
            params['min_' + param.name] = param.clip_lower_bound
            params['max_' + param.name] = param.clip_upper_bound
        test_domain_config.update_parameters(params, cache=False)

        full_env = ProcgenEnv(num_envs=1, env_name=str(test_domain_config.game), domain_config_path=str(test_domain_config_path))
        full_env = VecExtractDictObs(full_env, "rgb")
        full_env = VecMonitor(venv=full_env, filename=None, keep_buf=100)
        self.full_env = VecNormalize(venv=full_env, ob=False)
Example 17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision", choices=["agent", "human"], default="human")
    parser.add_argument("--record-dir", help="directory to record movies to")
    parser.add_argument("--distribution-mode", default="hard", help="which distribution mode to use for the level generation")
    parser.add_argument("--level-seed", type=int, help="select an individual level to use")
    parser.add_argument("--use-generated-assets", help="using autogenerated assets", choices=["yes","no"], default="no")
    args = parser.parse_args()

    kwargs = {"distribution_mode": args.distribution_mode}
    kwargs["use_generated_assets"] = True if (args.use_generated_assets == "yes") else False
    if args.level_seed is not None:
        kwargs["start_level"] = args.level_seed
        kwargs["num_levels"] = 1

    world_dim = int(10)
    kwargs["additional_info_spaces"] = [ProcgenEnv.C_Space("state", False, (7+world_dim*world_dim,), bytes, (0,255))]

    kwargs["options"] = {
        'world_dim':world_dim,
        'wall_chance':0.5,
        'fire_chance':0.3,
        'water_chance':0.2,
        'num_keys':int(2),
        'num_doors':int(1),
        'with_grid_steps':True,
        'completion_bonus':10.0,
        'fire_bonus':-5.0,
        'water_bonus':-2.0,
        'action_bonus':-1.0,
        }

    env = gym.make('procgen:procgen-heistpp-v0')

    # env = ProcgenEnv(num_envs=1, env_name="heistpp", **kwargs)

    obs = env.reset()
    step = 0
    while True:
        env.render()
        obs, rew, done, info = env.step(np.array([env.action_space.sample()]))
        print(f"step {step} reward {rew} done {done}")
        # print(info[0]['state'])
        step += 1
        if done:
            break

    env.close()
Example 18
def SB3_and_ProcgenEnv_example():
    # SB3 and ProcgenEnv.

    from procgen import ProcgenEnv

    # ProcgenEnv is already vectorized.
    venv = ProcgenEnv(num_envs=2, env_name="starpilot")

    # To use only part of the observation:
    #venv = VecExtractDictObs(venv, "rgb")

    # Wrap with a VecMonitor to collect stats and avoid errors.
    venv = VecMonitor(venv=venv)

    model = PPO("MultiInputPolicy", venv, verbose=1)
    model.learn(10_000)
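
The commented-out VecExtractDictObs line hints at the common variant in which only the "rgb" key is kept so an image policy can be used. A sketch of that variant with Stable-Baselines3's vectorized wrappers (same 10k-step budget as above):

from procgen import ProcgenEnv
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor

venv = ProcgenEnv(num_envs=2, env_name="starpilot")
venv = VecExtractDictObs(venv, "rgb")   # keep only the image observation
venv = VecMonitor(venv=venv)            # collect episode statistics
model = PPO("CnnPolicy", venv, verbose=1)
model.learn(10_000)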
Example 19
def evaluate(args, actor_critic, device, num_processes=1, aug_id=None):
    actor_critic.eval()

    # Sample Levels From the Full Distribution
    venv = ProcgenEnv(num_envs=num_processes, env_name=args.env_name, \
        num_levels=0, start_level=0, \
        distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    eval_envs = VecPyTorchProcgen(venv, device)

    eval_episode_rewards = []

    obs = eval_envs.reset()
    eval_recurrent_hidden_states = torch.zeros(
        num_processes, actor_critic.recurrent_hidden_state_size, device=device)
    eval_masks = torch.ones(num_processes, 1, device=device)

    while len(eval_episode_rewards) < 10:
        with torch.no_grad():
            if aug_id:
                obs = aug_id(obs)
            _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                obs,
                eval_recurrent_hidden_states,
                eval_masks,
                deterministic=False)

        obs, _, done, infos = eval_envs.step(action)

        eval_masks = torch.tensor([[0.0] if done_ else [1.0]
                                   for done_ in done],
                                  dtype=torch.float32,
                                  device=device)

        for info in infos:
            if 'episode' in info.keys():
                eval_episode_rewards.append(info['episode']['r'])

    eval_envs.close()

    print("Last {} test episodes: mean/median reward {:.1f}/{:.1f}\n"\
        .format(len(eval_episode_rewards), \
        np.mean(eval_episode_rewards), np.median(eval_episode_rewards)))

    return eval_episode_rewards
Example 20
    def __init__(self, config={}):
        self.config = default_config
        self.config.update(config)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        nstacks = self.config.pop('nstacks')

        env = ProcgenEnv(**self.config)
        env = VecExtractDictObs(env, "rgb")
        obs_shape = env.observation_space.shape

        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = TransposeImage(env, op=[2, 0, 1])

        env = VecPyTorch(env, device)
        env = VecPyTorchFrameStack(env, nstacks, device)
        self.venv = env
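
default_config is defined elsewhere in that module; the constructor pops the custom 'nstacks' entry and forwards the remaining keys to ProcgenEnv. A hypothetical example of what such a config could contain (illustrative values only):

default_config = {
    "num_envs": 8,               # forwarded to ProcgenEnv
    "env_name": "coinrun",
    "num_levels": 0,
    "start_level": 0,
    "distribution_mode": "easy",
    "nstacks": 4,                # popped before ProcgenEnv is built; used for frame stacking
}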
Example 21
def make_vec_envs_procgen(env_name,
                          num_envs,
                          start_level=0,
                          num_levels=0,
                          distribution_mode='hard',
                          normalize_obs=False,
                          normalize_ret=True,
                          num_frame_stack=1):
    env = ProcgenEnv(num_envs=num_envs,
                     env_name=env_name,
                     start_level=start_level,
                     num_levels=num_levels,
                     distribution_mode=distribution_mode)
    env = VecExtractDictTransposedObs(env, 'rgb')
    env = VecFrameStack(env, num_frame_stack)
    env = VecMonitor(env)
    env = VecNormalize(env, obs=normalize_obs, ret=normalize_ret)
    return env
Example 22
    def __init__(self, model, train_config_path: Union[str, pathlib.Path],
                 env_parameter: EnvironmentParameter, adr_config: ADRConfig):

        self._model = model  # Model being evaluated
        self._gamma = adr_config.gamma  # Discount rate
        self._lambda = adr_config.lmbda  # Lambda used in GAE (General Advantage Estimation)

        self._env_parameter = env_parameter
        self._param_name = self._env_parameter.name

        self._max_buffer_size = adr_config.max_buffer_size
        self._n_trajectories = adr_config.n_eval_trajectories
        self._upper_sample_prob = adr_config.upper_sample_prob

        self._train_config_path = pathlib.Path(train_config_path)
        config_dir = self._train_config_path.parent
        config_name = self._param_name + '_adr_eval_config.json'

        # Initialize the config for the evaluation environment
        # This config will be updated regularly throughout training. When we boundary sample this environment's
        # parameter, the config will be modified to set the parameter to the selected boundary before running a number
        # of trajectories.
        self._boundary_config = DomainConfig.from_json(self._train_config_path)
        self._boundary_config_path = config_dir / config_name
        self._boundary_config.to_json(self._boundary_config_path)

        # Initialize the environment
        env = ProcgenEnv(num_envs=1,
                         env_name=str(self._boundary_config.game),
                         domain_config_path=str(self._boundary_config_path))
        env = VecExtractDictObs(env, "rgb")
        env = VecMonitor(venv=env, filename=None, keep_buf=100)
        self._env = VecNormalize(venv=env, ob=False)

        # Initialize the performance buffers
        self._upper_performance_buffer, self._lower_performance_buffer = PerformanceBuffer(
        ), PerformanceBuffer()

        self._states = {
            'lower': model.adr_initial_state,
            'upper': model.adr_initial_state
        }
        self._obs = self._env.reset()
        self._dones = [False]
Example 23
def test_multi_speed(env_name, num_envs, benchmark):
    venv = ProcgenEnv(num_envs=num_envs, env_name=env_name)

    venv.reset()
    actions = np.zeros([venv.num_envs])

    def rollout(max_steps):
        step_count = 0
        while step_count < max_steps:
            _obs, _rews, _dones, _infos = venv.step(actions)
            step_count += 1

    benchmark(lambda: rollout(1000))

    venv.close()
Example 24
def test_fn(env_name, num_envs, config_path, load_path):
    test_config_path = os.path.join(os.getcwd(), "procgen-adr", config_path)
    test_env = ProcgenEnv(num_envs=num_envs, env_name=env_name, domain_config_path=test_config_path, render_mode="rgb_array")
    test_env = VecExtractDictObs(test_env, "rgb")
    test_env = VecMonitor(venv=test_env, filename=None, keep_buf=100)
    test_env = VecNormalize(venv=test_env, ob=False)

    setup_mpi_gpus()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True #pylint: disable=E1101
    sess = tf.Session(config=config)
    sess.__enter__()

    conv_fn = lambda x: build_impala_cnn(x, depths=[16,32,32], emb_size=256)

    recur = True
    if recur:
        logger.info("Using CNN LSTM")
        conv_fn = cnn_lstm(nlstm=256, conv_fn=conv_fn)

    mean, std = test(conv_fn, test_env, load_path=load_path)
    sess.close()
    return mean, std
Example 25
def train(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    log_file = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name,
                                             args.seed)

    venv = ProcgenEnv(num_envs=args.num_processes, env_name=args.env_name, \
        num_levels=args.num_levels, start_level=args.start_level, \
        distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape

    ################################
    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={
                              'recurrent': False,
                              'hidden_size': args.hidden_size
                          })
    actor_critic.to(device)

    ################################
    rollouts = RolloutStorage(args.num_steps,
                              args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              aug_type=args.aug_type,
                              split_ratio=args.split_ratio)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    ################################
    if args.use_ucb:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        agent = algo.UCBDrAC(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             ucb_exploration_coef=args.ucb_exploration_coef,
                             ucb_window_length=args.ucb_window_length)

    elif args.use_meta_learning:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size) \
            for t in list(aug_to_func.keys())]

        aug_model = AugCNN()
        aug_model.to(device)

        agent = algo.MetaDrAC(actor_critic,
                              aug_model,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              meta_grad_clip=args.meta_grad_clip,
                              meta_num_train_steps=args.meta_num_train_steps,
                              meta_num_test_steps=args.meta_num_test_steps,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_coef=args.aug_coef)

    elif args.use_rl2:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        rl2_obs_shape = [envs.action_space.n + 1]
        rl2_learner = Policy(rl2_obs_shape,
                             len(list(aug_to_func.keys())),
                             base_kwargs={
                                 'recurrent': True,
                                 'hidden_size': args.rl2_hidden_size
                             })
        rl2_learner.to(device)

        agent = algo.RL2DrAC(actor_critic,
                             rl2_learner,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.rl2_entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             rl2_lr=args.rl2_lr,
                             rl2_eps=args.rl2_eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             recurrent_hidden_size=args.rl2_hidden_size,
                             num_actions=envs.action_space.n,
                             device=device)

    else:
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name)

    checkpoint_path = os.path.join(args.save_dir, "agent" + log_file + ".pt")
    if os.path.exists(checkpoint_path) and args.preempt:
        checkpoint = torch.load(checkpoint_path)
        agent.actor_critic.load_state_dict(checkpoint['model_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch'] + 1
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout'],
                         log_suffix=log_file + "-e%s" % init_epoch)
    else:
        init_epoch = 0
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout'],
                         log_suffix=log_file)

    obs = envs.reset()  # reset the wrapped envs
    rollouts.obs[0].copy_(obs)  # store the initial observation
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    # args.num_steps -> 256 (number of forward steps in A2C)
    # args.num_env_steps -> 25e6 (number of environment steps to train)
    num_updates = int(
        args.num_env_steps) // args.num_processes // args.num_steps

    # todo: these updates are treated as epochs, but the step at which each episode terminates differs...
    for j in range(init_epoch, num_updates):
        actor_critic.train()
        for step in range(args.num_steps):

            # Sample actions
            with torch.no_grad():
                obs_id = aug_id(rollouts.obs[step])
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            # todo : check the shapes of obs, reward, done, infos
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    # todo : difference between reward and info['episode']['r']
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            obs_id = aug_id(rollouts.obs[-1])
            # todo : what is next_value for?
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        if args.use_ucb and j > 0:  # from second epoch
            agent.update_ucb_values(rollouts)  # update ucb

        # todo: this update step is quite involved
        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        # roll the storage over for the next update
        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "\nUpdate {}, step {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}"
                .format(j, total_num_steps, len(episode_rewards),
                        np.mean(episode_rewards), np.median(episode_rewards),
                        dist_entropy, value_loss, action_loss))

            logger.logkv("train/nupdates", j)
            logger.logkv("train/total_num_steps", total_num_steps)

            logger.logkv("losses/dist_entropy", dist_entropy)
            logger.logkv("losses/value_loss", value_loss)
            logger.logkv("losses/action_loss", action_loss)

            logger.logkv("train/mean_episode_reward", np.mean(episode_rewards))
            logger.logkv("train/median_episode_reward",
                         np.median(episode_rewards))

            ### Eval on the Full Distribution of Levels ###
            eval_episode_rewards = evaluate(args,
                                            actor_critic,
                                            device,
                                            aug_id=aug_id)

            logger.logkv("test/mean_episode_reward",
                         np.mean(eval_episode_rewards))
            logger.logkv("test/median_episode_reward",
                         np.median(eval_episode_rewards))

            logger.dumpkvs()

        # Save Model
        if (j > 0 and j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            try:
                os.makedirs(args.save_dir)
            except OSError:
                pass

            torch.save(
                {
                    'epoch': j,
                    'model_state_dict': agent.actor_critic.state_dict(),
                    'optimizer_state_dict': agent.optimizer.state_dict(),
                }, os.path.join(args.save_dir, "agent" + log_file + ".pt"))
Example 26
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = ProcgenEnv(num_envs=args.num_envs, env_name=args.gym_id, num_levels=0, start_level=0, distribution_mode="easy")
    envs = gym.wrappers.TransformObservation(envs, lambda obs: obs["rgb"])
    envs.single_action_space = envs.action_space
    envs.single_observation_space = envs.observation_space["rgb"]
    envs.is_vector_env = True
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = gym.wrappers.NormalizeReward(envs)
    envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10))
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
Example 27
    eps = .2
    
    set_global_seeds(seed)
    set_global_log_levels(log_level)

    ### DEVICE ###
    device = torch.device('cuda')
    #print("dev", device)
   
    ### ENVIRONMENT ###
    print('INITIALIZING THE ENVIRONMENTS...............')
    n_steps = 256
    n_envs = 64
    torch.set_num_threads(1)    # increasing the number of threads will usually lead to faster execution on CPU
    env = ProcgenEnv(num_envs=n_envs,env_name=env_name,start_level=start_level,
                    num_levels=num_levels,distribution_mode=distribution_mode,
                    use_backgrounds=False,restrict_themes=True)
    
    normalize_reward = True
    env = VecExtractDictObs(env, "rgb")
    if normalize_reward:
        env = VecNormalize(env, ob=False)  # normalizing returns, but not the img frames
    env = TransposeFrame(env)
    env = ScaledFloatFrame(env)

    v_env = ProcgenEnv(num_envs=n_envs,env_name=env_name,start_level=start_level,
                    num_levels=num_levels,distribution_mode=distribution_mode,
                    use_backgrounds=False,restrict_themes=True)
    v_env = VecExtractDictObs(v_env, "rgb")
    if normalize_reward:
        v_env = VecNormalize(v_env, ob=False)
Example 28
def main():

    args = parse_config()
    run_dir = log_this(args, args.log_dir,
                       args.log_name + '_' + args.env_name + '_' + args.rm_id)

    test_worker_interval = args.test_worker_interval

    comm = MPI.COMM_WORLD

    is_test_worker = False

    if test_worker_interval > 0:
        is_test_worker = comm.Get_rank() % test_worker_interval == (
            test_worker_interval - 1)

    mpi_rank_weight = 0 if is_test_worker else 1

    log_comm = comm.Split(1 if is_test_worker else 0, 0)
    format_strs = ['csv', 'stdout'] if log_comm.Get_rank() == 0 else []
    logger.configure(dir=run_dir, format_strs=format_strs)

    logger.info("creating environment")

    venv = ProcgenEnv(num_envs=args.num_envs,
                      env_name=args.env_name,
                      num_levels=args.num_levels,
                      start_level=args.start_level,
                      distribution_mode=args.distribution_mode,
                      use_sequential_levels=args.use_sequential_levels)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)

    if args.rm_id:
        # load pretrained network
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net = RewardNet().to(device)
        rm_path = glob.glob('./**/' + args.rm_id + '.rm', recursive=True)[0]
        net.load_state_dict(
            torch.load(rm_path, map_location=torch.device(device)))

        # use batch reward prediction function instead of the ground truth reward function
        # pass through sigmoid if needed
        if args.use_sigmoid:
            rew_func = lambda x: 1 / (1 + np.exp(-net.predict_batch_rewards(x)))
        else:
            rew_func = lambda x: net.predict_batch_rewards(x)

        ## Uncomment the line below to train a live-long agent
        # rew_func = lambda x: x.shape[0] * [1]

        venv = ProxyRewardWrapper(venv, rew_func)
    else:
        # true environment rewards will be used
        pass

    venv = VecNormalize(venv=venv, ob=False, use_tf=False)

    # do the rest of the training as normal
    logger.info("creating tf session")
    setup_mpi_gpus()
    config = tf.ConfigProto()

    config.gpu_options.allow_growth = True  #pylint: disable=E1101
    sess = tf.Session(config=config)

    sess.__enter__()

    conv_fn = lambda x: build_impala_cnn(x, depths=[16, 32, 32], emb_size=256)

    logger.info("training")

    model = ppo2.learn(
        env=venv,
        network=conv_fn,
        total_timesteps=args.timesteps_per_proc,
        save_interval=args.save_interval,
        nsteps=args.nsteps,
        nminibatches=args.nminibatches,
        lam=args.lam,
        gamma=args.gamma,
        noptepochs=args.ppo_epochs,
        log_interval=args.log_interval,
        ent_coef=args.ent_coef,
        mpi_rank_weight=mpi_rank_weight,
        clip_vf=args.use_vf_clipping,
        comm=comm,
        lr=args.learning_rate,
        cliprange=args.clip_range,
        update_fn=None,
        init_fn=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        load_path=args.load_path,
    )

    model.save(os.path.join(run_dir, 'final_model.parameters'))
Example 29
def main():

    parser = argparse.ArgumentParser(
        description='Process procgen training arguments.')
    parser.add_argument('--env_name', type=str, default='fruitbot')
    parser.add_argument(
        '--distribution_mode',
        type=str,
        default='easy',
        choices=["easy", "hard", "exploration", "memory", "extreme"])
    parser.add_argument('--num_levels', type=int, default=50)
    parser.add_argument('--start_level', type=int, default=0)
    parser.add_argument('--test_worker_interval', type=int, default=0)
    parser.add_argument('--run_id', '-id', type=int, default=99)
    parser.add_argument('--nupdates', type=int, default=0)
    parser.add_argument('--total_tsteps', type=int, default=0)
    parser.add_argument('--log_interval', type=int, default=5)
    parser.add_argument('--load_id', type=int, default=int(-1))
    parser.add_argument('--nrollouts', '-nroll', type=int, default=0)
    parser.add_argument('--test', default=False, action="store_true")
    parser.add_argument('--use_model',
                        type=int,
                        default=1,
                        help="either model #1 or #2")
    parser.add_argument('--train_level', type=int, default=50)

    args = parser.parse_args()
    #timesteps_per_proc
    if args.nupdates:
        timesteps_per_proc = int(args.nupdates * num_envs * nsteps)
    if not args.total_tsteps:
        args.total_tsteps = TIMESTEPS_PER_PROC  ## use global 20_000_000 if not specified in args!
    if args.nrollouts:
        total_timesteps = int(args.nrollouts * num_envs * nsteps)

    run_ID = 'run_' + str(args.run_id).zfill(2)
    if args.test:
        args.log_interval = 1
        args.total_tsteps = 1_000_000
        run_ID += '_test{}_model{}'.format(args.load_id, args.use_model)

    load_path = None
    if args.load_id > -1:
        load_path = join(SAVE_PATH, args.env_name,
                         'saved_ensemble2_v{}.tar'.format(args.load_id))

    test_worker_interval = args.test_worker_interval
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    is_test_worker = False
    if test_worker_interval > 0:
        is_test_worker = comm.Get_rank() % test_worker_interval == (
            test_worker_interval - 1)

    mpi_rank_weight = 0 if is_test_worker else 1
    num_levels = 0 if is_test_worker else args.num_levels

    log_comm = comm.Split(1 if is_test_worker else 0, 0)
    format_strs = ['csv', 'stdout', 'log'] if log_comm.Get_rank() == 0 else []

    if args.test:
        logpath = join('log2/ensemble2', args.env_name, 'test', run_ID)
    else:
        logpath = join('log2/ensemble2', args.env_name, 'train', run_ID)
        save_path = join(SAVE_PATH, args.env_name,
                         'saved_ensemble2_v{}.tar'.format(args.run_id))
        logger.info("\n Model will be saved to file {}".format(save_path))

    if not os.path.exists(logpath):
        os.system("mkdir -p %s" % logpath)
    logger.configure(dir=logpath, format_strs=format_strs)

    fpath = join(logpath, 'args_{}.json'.format(run_ID))
    with open(fpath, 'w') as fh:
        json.dump(vars(args), fh, indent=4, sort_keys=True)
    print("\nSaved args at:\n\t{}\n".format(fpath))

    logger.info("creating tf session")
    setup_mpi_gpus()

    if not args.test:
        config = tf.compat.v1.ConfigProto(\
            allow_soft_placement=True,
            log_device_placement=True)# device_count={'GPU':0})
        config.gpu_options.allow_growth = True  #pylint: disable=E1101
        sess = tf.compat.v1.Session(config=config)
        logger.info("creating 2 environments")
        n_levels = int(args.num_levels / 2)
        env1 = ProcgenEnv(num_envs=num_envs,
                          env_name=args.env_name,
                          num_levels=n_levels,
                          start_level=0,
                          distribution_mode=args.distribution_mode)
        env1 = VecExtractDictObs(env1, "rgb")
        env1 = VecMonitor(
            venv=env1,
            filename=None,
            keep_buf=100,
        )
        env1 = VecNormalize(venv=env1, ob=False)

        env2 = ProcgenEnv(num_envs=num_envs,
                          env_name=args.env_name,
                          num_levels=n_levels,
                          start_level=n_levels,
                          distribution_mode=args.distribution_mode)
        env2 = VecExtractDictObs(env2, "rgb")
        env2 = VecMonitor(
            venv=env2,
            filename=None,
            keep_buf=100,
        )
        env2 = VecNormalize(venv=env2, ob=False)

        train(run_ID, save_path, load_path, env1, env2, sess, logger, args)
    else:
        use_model = args.use_model  ## 1 or 2
        alt_flag = use_model - 1
        test_all(alt_flag, load_path, logger, args)
Example 30
def train_fn(env_name: str,
             num_train_envs: int,
             n_training_steps: int,
             adr_config: ADRConfig = None,
             experiment_dir: str = None,
             tunable_params_config_path: str = None,
             log_dir: str = None,
             is_test_worker: bool = False,
             comm=None,
             save_interval: int = 1000,
             log_interval: int = 20,
             recur: bool = True):

    # Get the default ADR config if none is provided
    adr_config = ADRConfig() if adr_config is None else adr_config

    # Set up the experiment directory for this run. This will contain everything, from the domain configs for the
    # training environment and ADR evaluation environments to the logs. If the directory path is not provided, then
    # we'll make one and use the date-time name to make it unique
    if experiment_dir is None:
        experiment_dir = pathlib.Path().absolute() / 'adr_experiments' / (
            'experiment-' + datetime_name())
        experiment_dir.mkdir(parents=True, exist_ok=False)
    else:
        experiment_dir = pathlib.Path(experiment_dir)

    # Make a config directory within the experiment directory to hold the domain configs
    config_dir = experiment_dir / 'domain_configs'
    config_dir.mkdir(parents=True, exist_ok=False)

    # Load the tunable parameters from a config file if it is provided, otherwise get the default for the given game.
    if tunable_params_config_path is None:
        try:
            tunable_params = DEFAULT_TUNABLE_PARAMS[env_name]
        except KeyError:
            raise KeyError(
                f'No default tunable parameters exist for {env_name}')
    else:
        raise NotImplementedError(
            'Currently no way to load tunable parameters from a configuration file'
        )

    # Make a default config for the given game...
    train_domain_config_path = config_dir / 'train_config.json'
    try:
        train_domain_config = DEFAULT_DOMAIN_CONFIGS[env_name]
        train_domain_config.to_json(train_domain_config_path)
    except KeyError:
        raise KeyError(f'No default config exists for {env_name}')

    # ...then load the initial bounds for the tunable parameters into the config.
    params = {}
    for param in tunable_params:
        params['min_' + param.name] = param.lower_bound
        params['max_' + param.name] = param.upper_bound
    train_domain_config.update_parameters(params, cache=False)

    # Configure the logger if we are given a log directory
    if log_dir is not None:
        log_dir = experiment_dir / log_dir
        log_comm = comm.Split(1 if is_test_worker else 0, 0)
        format_strs = ['csv', 'stdout'] if log_comm.Get_rank() == 0 else []
        logger.configure(comm=log_comm,
                         dir=str(log_dir),
                         format_strs=format_strs)

    logger.info(f'env_name: {env_name}')
    logger.info(f'num_train_envs: {num_train_envs}')
    logger.info(f'n_training_steps: {n_training_steps}')
    logger.info(f'experiment_dir: {experiment_dir}')
    logger.info(f'tunable_params_config_path: {tunable_params_config_path}')
    logger.info(f'log_dir: {log_dir}')
    logger.info(f'save_interval: {save_interval}')

    n_steps = 256
    ent_coef = .01
    lr = 5e-4
    vf_coef = .5
    max_grad_norm = .5
    gamma = .999
    lmbda = .95
    n_minibatches = 8
    ppo_epochs = 3
    clip_range = .2
    use_vf_clipping = True

    mpi_rank_weight = 0 if is_test_worker else 1

    logger.info('creating environment')
    training_env = ProcgenEnv(num_envs=num_train_envs,
                              env_name=env_name,
                              domain_config_path=str(train_domain_config_path))
    training_env = VecExtractDictObs(training_env, "rgb")
    training_env = VecMonitor(venv=training_env, filename=None, keep_buf=100)
    training_env = VecNormalize(venv=training_env, ob=False)

    logger.info("creating tf session")
    setup_mpi_gpus()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.__enter__()

    def conv_fn(x):
        return build_impala_cnn(x, depths=[16, 32, 32], emb_size=256)

    if recur:
        logger.info("Using CNN LSTM")
        conv_fn = cnn_lstm(nlstm=256, conv_fn=conv_fn)

    logger.info('training')
    ppo2_adr.learn(conv_fn,
                   training_env,
                   n_training_steps,
                   config_dir,
                   adr_config,
                   train_domain_config,
                   tunable_params,
                   n_steps=n_steps,
                   ent_coef=ent_coef,
                   lr=lr,
                   vf_coef=vf_coef,
                   max_grad_norm=max_grad_norm,
                   gamma=gamma,
                   lmbda=lmbda,
                   log_interval=log_interval,
                   save_interval=save_interval,
                   n_minibatches=n_minibatches,
                   n_optepochs=ppo_epochs,
                   clip_range=clip_range,
                   mpi_rank_weight=mpi_rank_weight,
                   clip_vf=use_vf_clipping)