Code example #1
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    if env_type == 'atari':
        print("making atari")
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id)

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
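As a rough usage sketch (not part of the original project), the make_env function above can be called directly to build a single seeded retro environment; the game name, state, and wrapper arguments below are illustrative assumptions and presume the corresponding ROM is importable through retro:

# Hypothetical usage sketch for make_env above; the game, state, and kwargs are illustrative only.
env = make_env('SuperMarioBros-Nes', 'retro', subrank=0, seed=0,
               gamestate='Level1-1', wrapper_kwargs={'frame_stack': 4})
obs, done, episode_rew = env.reset(), False, 0.0
while not done:
    obs, rew, done, _ = env.step(env.action_space.sample())  # random policy, just to exercise the env
    episode_rew += rew
print('random-policy episode reward:', episode_rew)
env.close()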
Code example #2
File: cmd_util.py Project: grockious/deepsynth
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
    wrapper_kwargs = wrapper_kwargs or {}
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
Code example #3
File: enjoy_retro.py Project: dineshj1/baselines-mod
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env',
                        help='environment ID',
                        default='SuperMarioBros-Nes')
    parser.add_argument('--gamestate',
                        help='game state to load',
                        default='Level1-1')
    parser.add_argument('--model',
                        help='model pickle file from ActWrapper.save',
                        default='model.pkl')
    args = parser.parse_args()

    env = retro_wrappers.make_retro(game=args.env,
                                    state=args.gamestate,
                                    max_episode_steps=None)
    env = retro_wrappers.wrap_deepmind_retro(env)
    act = deepq.load(args.model)

    while True:
        obs, done = env.reset(), False
        episode_rew = 0
        while not done:
            env.render()
            action = act(obs[None])[0]
            env_action = np.zeros(env.action_space.n)
            env_action[action] = 1
            obs, rew, done, _ = env.step(env_action)
            episode_rew += rew
        print('Episode reward', episode_rew)
Code example #4
def build_env(args):
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type == 'atari':
        if alg == 'acer':
            env = make_vec_env(env_id, env_type, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(
                env,
                logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed),
                                frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    else:
        get_session(
            tf.ConfigProto(allow_soft_placement=True,
                           intra_op_parallelism_threads=1,
                           inter_op_parallelism_threads=1))

        env = make_vec_env(env_id,
                           env_type,
                           args.num_env or 1,
                           seed,
                           reward_scale=args.reward_scale)

        if env_type == 'mujoco':
            env = VecNormalize(env)

    return env
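A minimal way to exercise build_env above is to hand it an argparse-style namespace; the sketch below is hypothetical, and the field names simply mirror the attributes the function reads (env, alg, num_env, seed, gamestate, reward_scale):

# Hypothetical usage sketch: drive build_env with a hand-built namespace; values are illustrative.
import argparse
args = argparse.Namespace(env='PongNoFrameskip-v4', alg='ppo2', num_env=4,
                          seed=0, gamestate=None, reward_scale=1.0)
env = build_env(args)   # an Atari id with ppo2 falls through to the VecFrameStack branch
obs = env.reset()       # typically shaped (num_env, 84, 84, 4) after frame stacking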
Code example #5
File: run_retro.py Project: dineshj1/baselines-mod
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes')
    parser.add_argument('--gamestate', help='game state to load', default='Level1-1')
    parser.add_argument('--seed', help='seed', type=int, default=0)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
    env.seed(args.seed)
    env = bench.Monitor(env, logger.get_dir())
    env = retro_wrappers.wrap_deepmind_retro(env)

    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True
    )
    act.save()
    env.close()
Code example #6
File: cmd_util.py Project: MrGoogol/baselines
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
    wrapper_kwargs = wrapper_kwargs or {}
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
Code example #7
File: envs.py Project: xkianteb/dril
    def _thunk():
        if env_id.startswith("dm"):
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        elif env_id in ['duckietown']:
            from a2c_ppo_acktr.duckietown.env import launch_env
            from a2c_ppo_acktr.duckietown.wrappers import NormalizeWrapper, ImgWrapper,\
                 DtRewardWrapper, ActionWrapper, ResizeWrapper
            from a2c_ppo_acktr.duckietown.teacher import PurePursuitExpert
            env = launch_env()
            env = ResizeWrapper(env)
            env = NormalizeWrapper(env)
            env = ImgWrapper(env)
            env = ActionWrapper(env)
            env = DtRewardWrapper(env)
        elif env_id in retro_envs:
            env = make_retro(game=env_id)
            #env = SuperMarioKartDiscretizer(env)
        else:
            env = gym.make(env_id)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id, max_episode_steps=max_steps)

        env.seed(seed + rank)

        # TODO: Figure out what to do here
        if is_atari:
            env = TimeLimitMask(env)

        if log_dir is not None:
            env = bench.Monitor(env,
                                os.path.join(log_dir, str(rank)),
                                allow_early_resets=allow_early_resets)

        if is_atari:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env)
        elif env_id in retro_envs:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind_retro(env, frame_stack=0)
        elif len(env.observation_space.shape) == 3:
            if env_id not in ['duckietown'] and env_id not in retro_envs:
                raise NotImplementedError(
                    "CNN models work only for atari,\n"
                    "please use a custom wrapper for a custom pixel input env.\n"
                    "See wrap_deepmind for an example.")

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        if env_id not in ['duckietown']:
            obs_shape = env.observation_space.shape
            if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
                env = TransposeImage(env, op=[2, 0, 1])

        if time:
            env = TimeFeatureWrapper(env)

        return env
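The TransposeImage(env, op=[2, 0, 1]) step above reorders observations from HWC to CHW for PyTorch convolutions; a standalone numpy illustration of that reordering (a sketch, not the wrapper itself) looks like this:

# Standalone sketch of the HWC -> CHW reordering performed by TransposeImage(op=[2, 0, 1]).
import numpy as np
frame_hwc = np.zeros((84, 84, 3), dtype=np.uint8)   # (height, width, channels)
frame_chw = frame_hwc.transpose(2, 0, 1)            # (channels, height, width)
assert frame_chw.shape == (3, 84, 84)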
Code example #8
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import re
        import importlib
        module_name = re.sub(':.*','',env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        # here create our own environment which should be able to handle the parallelism:
        if (env_type in {'nf-par'}):
            if env_id == 'Pendulumnf-v0':
                from gym.envs.registration import register
                register(
                    id='Pendulumnf-v0',
                    entry_point='nfunk.envs_nf.pendulum_nf:PendulumEnv',
                    max_episode_steps=200,
                )
                env = gym.make(env_id, **env_kwargs)

        else:
            env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)


    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
Code example #9
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import re
        import importlib
        module_name = re.sub(':.*','',env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        from baselines.common.atari_wrappers import make_atari  # delayed loading of deps
        env = make_atari(env_id)
    elif env_type == 'retro':
        from baselines.common import retro_wrappers
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)


    if env_type == 'atari':
        from baselines.common.atari_wrappers import wrap_deepmind  # delayed loading of deps
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        from baselines.common import retro_wrappers
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        from baselines.common import retro_wrappers
        env = retro_wrappers.RewardScaler(env, reward_scale)
    try:
        env.giveRank(subrank=subrank)
    except Exception as exc:
        print("ignoring exception", exc, "in baselines make_env")
        pass

    return env
Code example #10
    def make_env():
        # To record gameplay video, pass record='.' to make_retro
        env = make_retro(game=args.game, state=args.state, scenario=args.scenario, record='.')
        env = wrap_deepmind_retro(env)
        env = Monitor(env, None, True)
        # and run the random-action episode below so the recording captures a full run
        env.reset()
        while True:
            _obs, _rew, done, _info = env.step(env.action_space.sample())
            if done:
                break

        return env
Code example #11
def build_env(args):
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type == 'atari':
        if alg == 'acer':
            env = make_vec_env(env_id, env_type, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
                                        use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    else:
        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale,
                        steps_until_done=args.env_steps, cont=args.env_cont, norm=args.env_norm,
                        start_index=args.start_index)

        if env_type == 'mujoco':
            env = VecNormalize(env)
    return env
Code example #12
File: cmd_util.py Project: ProFrenchToast/comp300
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import re
        import importlib
        module_name = re.sub(':.*','',env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)


    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
Code example #13
def build_env(args, selector=None):
    global store
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    print(env_type, env_id, nenv, args.num_env)
    if env_type == 'mujoco':
        get_session(
            tf.ConfigProto(allow_soft_placement=True,
                           intra_op_parallelism_threads=1,
                           inter_op_parallelism_threads=1))

        if args.num_env:
            env = SubprocVecEnv([
                lambda: make_mujoco_env(env_id, seed + i if seed is not None
                                        else None, args.reward_scale)
                for i in range(args.num_env)
            ])
        else:
            env = DummyVecEnv(
                [lambda: make_mujoco_env(env_id, seed, args.reward_scale)])

        env = VecNormalize(env)

    elif env_type == 'atari':
        if alg == 'acer':
            env = make_atari_env(
                env_id, nenv, seed)  #, wrapper_kwargs={'clip_rewards': False})
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env,
                                               frame_stack=True,
                                               scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(
                env,
                logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        elif "Zelda" in env_id:
            sys.path.append(
                "/home/jupyter/Notebooks/Chang/HardRLWithYoutube/nnrunner/a2c_gvgai"
            )
            import nnrunner.a2c_gvgai.env as gvgai_env
            frame_stack_size = 4
            print("run zelda")
            env = VecFrameStack(
                gvgai_env.make_gvgai_env(env_id,
                                         nenv,
                                         seed,
                                         level_selector=selector,
                                         experiment="PE",
                                         dataset="zelda"), frame_stack_size)
            # env.reset()
            # store = env
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_atari_env(env_id, nenv, seed),
                                frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    elif env_type == 'classic_control':

        def make_env():
            e = gym.make(env_id)
            e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
            e.seed(seed)
            return e

        env = DummyVecEnv([make_env])

    else:
        raise ValueError('Unknown env_type {}'.format(env_type))

    # env.reset()
    print("build env")
    # store.reset()
    # store.reset()

    return env
Code example #14
def make_env(env_id,
             env_type,
             subrank=0,
             seed=None,
             reward_scale=1.0,
             gamestate=None,
             flatten_dict_observations=True,
             wrapper_kwargs=None):
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    wrapper_kwargs = wrapper_kwargs or {}
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=env_id,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
            state=gamestate)
    elif env_type == 'starcraft2':
        import sc2gym.envs
        from absl import flags
        from pysc2.lib import point_flag
        from pysc2.env import sc2_env

        FLAGS = flags.FLAGS
        FLAGS([__file__])
        env = gym.make(env_id)
        env.settings['visualize'] = True
        env.settings[
            'agent_interface_format'] = sc2_env.parse_agent_interface_format(
                feature_screen=32,
                feature_minimap=32,
                rgb_screen=None,
                rgb_minimap=None,
                action_space="features",
                use_feature_units=False)
    else:
        env = gym.make(env_id)

    if flatten_dict_observations and isinstance(env.observation_space,
                                                gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger.get_dir()
                  and os.path.join(logger.get_dir(),
                                   str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
Code example #15
File: run.py Project: williamd4112/baselines
def build_env(args):
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    if env_type == 'mujoco':
        get_session(
            tf.ConfigProto(allow_soft_placement=True,
                           intra_op_parallelism_threads=1,
                           inter_op_parallelism_threads=1))

        if args.num_env:
            env = SubprocVecEnv([
                lambda: make_mujoco_env(env_id, seed + i if seed is not None
                                        else None, args.reward_scale)
                for i in range(args.num_env)
            ])
        else:
            env = DummyVecEnv(
                [lambda: make_mujoco_env(env_id, seed, args.reward_scale)])

        env = VecNormalize(env)

    elif env_type == 'atari':
        if alg == 'acer':
            env = make_atari_env(env_id, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env,
                                               frame_stack=True,
                                               scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(
                env,
                logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_atari_env(env_id, nenv, seed),
                                frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    elif env_type == 'classic_control':

        def make_env():
            e = gym.make(env_id)
            e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
            e.seed(seed)
            return e

        env = DummyVecEnv([make_env])

    else:
        raise ValueError('Unknown env_type {}'.format(env_type))

    return env
Code example #16
def make_env(env_id,
             env_type,
             mpi_rank=0,
             subrank=0,
             seed=None,
             reward_scale=1.0,
             gamestate=None,
             flatten_dict_observations=True,
             wrapper_kwargs=None,
             env_kwargs=None,
             logger_dir=None,
             initializer=None):
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import re
        import importlib
        module_name = re.sub(':.*', '', env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=env_id,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
            state=gamestate)
    elif env_type == 'robotics':
        env = gym.make(env_id)
        env = FlattenDictWrapper(
            env, ['observation', 'achieved_goal', 'desired_goal'])
    else:
        if env_id == 'LunarLanderContinuousPOMDP-v0':
            env = new_lunar_lander_pomdp_env(hist_len=hist_len,
                                             block_high=block_high,
                                             not_guided=not_guided,
                                             give_state=give_state)
        else:
            env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space,
                                                gym.spaces.Dict):
        keys = env.observation_space.spaces.keys()
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir
                  and os.path.join(logger_dir,
                                   str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
Code example #17
def build_env(args):
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin': ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type == 'atari':
        if alg == 'acer':
            env = make_vec_env(env_id, env_type, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env,
                                               frame_stack=True,
                                               scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(
                env,
                logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed),
                                frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    elif env_type == 'AirHockey':
        from gym_airhockey.configuration import configure_env
        from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

        version_list = [x for x in args.versions if x is not None]
        version = version_list[
            MPI.COMM_WORLD.Get_rank() %
            len(version_list)]  # Each rank gets its own version

        # setup the environment
        env = gym.make(env_id)
        env.seed(args.seed)
        configure_env(env, version=version)

        # wrap the environment
        env = bench.Monitor(env, logger.get_dir(), allow_early_resets=True)
        env = DummyVecEnv([lambda: env])
        env.render()

    else:
        get_session(
            tf.ConfigProto(allow_soft_placement=True,
                           intra_op_parallelism_threads=1,
                           inter_op_parallelism_threads=1))

        env = make_vec_env(env_id,
                           env_type,
                           args.num_env or 1,
                           seed,
                           reward_scale=args.reward_scale)

        if env_type == 'mujoco':
            env = VecNormalize(env)

    return env
Code example #18
File: ppo.py Project: laurenmoos/weird-mario
def make_env():
    env = make_retro(game=args.game,
                     state=args.state,
                     scenario=args.scenario)
    env = wrap_deepmind_retro(env)
    return env
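Closures like the make_env above are typically handed to a vectorized environment; a hedged sketch of that step (assuming baselines' SubprocVecEnv and VecFrameStack are available, which the surrounding PPO script is not shown importing) would be:

# Hypothetical usage sketch: vectorize the make_env closure for PPO-style training.
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

num_envs = 8
venv = SubprocVecEnv([make_env for _ in range(num_envs)])  # one subprocess per env copy
venv = VecFrameStack(venv, 4)                              # stack 4 frames, as the Atari examples above do
obs = venv.reset()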