Exemplo n.º 1
0
def build_env(args):
    """Build the environment described by ``args.env`` for algorithm ``args.alg``.

    The environment type is resolved with ``get_env_type(args.env)``, and each
    type has its own construction/wrapping recipe:

    * ``mujoco``: configures a TF session with single-threaded intra/inter op
      parallelism, builds ``args.num_env`` envs in a ``SubprocVecEnv`` (or one
      env in a ``DummyVecEnv`` when ``args.num_env`` is falsy), then wraps the
      result in ``VecNormalize``.
    * ``atari``: depends on ``args.alg`` -- ``acer``, ``deepq``, ``trpo_mpi``,
      or by default ``nenv`` envs stacked 4 frames deep with ``VecFrameStack``.
    * ``retro``: a single retro env starting from ``args.gamestate``
      (``'Level1-1'`` when unset).
    * ``classic_control``: a single monitored gym env in a ``DummyVecEnv``.

    Args:
        args: parsed options; this function reads ``env``, ``alg``,
            ``num_env``, ``seed``, ``reward_scale`` and ``gamestate``.

    Returns:
        The constructed environment. The ``deepq`` / ``trpo_mpi`` Atari
        branches and the ``retro`` branch return a single wrapped env rather
        than a vectorized one.

    Raises:
        ValueError: if ``get_env_type`` reports an env type not handled here.
    """
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        # Halve the reported CPU count on macOS (presumably to discount
        # hyper-threads), but never drop to zero environments.
        ncpu = max(1, ncpu // 2)
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    if env_type == 'mujoco':
        get_session(
            tf.ConfigProto(allow_soft_placement=True,
                           intra_op_parallelism_threads=1,
                           inter_op_parallelism_threads=1))

        if args.num_env:

            def make_thunk(env_rank):
                # Compute this worker's seed eagerly. A lambda written directly
                # inside the comprehension would look up the loop variable only
                # when called/pickled, so every worker would get the last seed.
                env_seed = seed + env_rank if seed is not None else None
                return lambda: make_mujoco_env(env_id, env_seed,
                                               args.reward_scale)

            env = SubprocVecEnv(
                [make_thunk(i) for i in range(args.num_env)])
        else:
            env = DummyVecEnv(
                [lambda: make_mujoco_env(env_id, seed, args.reward_scale)])

        env = VecNormalize(env)

    elif env_type == 'atari':
        if alg == 'acer':
            env = make_atari_env(env_id, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env,
                                               frame_stack=True,
                                               scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            # Per-rank monitor directory; stays None when no log dir is set.
            env = bench.Monitor(
                env,
                logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_atari_env(env_id, nenv, seed),
                                frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    elif env_type == 'classic_control':

        def make_env():
            e = gym.make(env_id)
            e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
            e.seed(seed)
            return e

        env = DummyVecEnv([make_env])

    else:
        raise ValueError('Unknown env_type {}'.format(env_type))

    return env
Exemplo n.º 2
0
def build_env(args, selector=None):
    """Build the environment described by ``args.env`` for algorithm ``args.alg``.

    The environment type is resolved with ``get_env_type(args.env)``, and each
    type has its own construction/wrapping recipe:

    * ``mujoco``: configures a TF session with single-threaded intra/inter op
      parallelism, builds ``args.num_env`` envs in a ``SubprocVecEnv`` (or one
      env in a ``DummyVecEnv`` when ``args.num_env`` is falsy), then wraps the
      result in ``VecNormalize``.
    * ``atari``: depends on ``args.alg`` (``acer``, ``deepq``, ``trpo_mpi``).
      For any other algorithm, an ``env_id`` containing ``"Zelda"`` builds a
      GVGAI env via ``nnrunner.a2c_gvgai.env.make_gvgai_env``; otherwise
      ``nenv`` Atari envs are built. Both of these are stacked 4 frames deep
      with ``VecFrameStack``.
    * ``retro``: a single retro env starting from ``args.gamestate``
      (``'Level1-1'`` when unset).
    * ``classic_control``: a single monitored gym env in a ``DummyVecEnv``.

    Args:
        args: parsed options; this function reads ``env``, ``alg``,
            ``num_env``, ``seed``, ``reward_scale`` and ``gamestate``.
        selector: forwarded as ``level_selector`` to ``make_gvgai_env``; only
            used by the Zelda branch.

    Returns:
        The constructed environment. The ``deepq`` / ``trpo_mpi`` Atari
        branches and the ``retro`` branch return a single wrapped env rather
        than a vectorized one.

    Raises:
        ValueError: if ``get_env_type`` reports an env type not handled here.
    """
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        # Halve the reported CPU count on macOS (presumably to discount
        # hyper-threads), but never drop to zero environments.
        ncpu = max(1, ncpu // 2)
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    print(env_type, env_id, nenv, args.num_env)
    if env_type == 'mujoco':
        get_session(
            tf.ConfigProto(allow_soft_placement=True,
                           intra_op_parallelism_threads=1,
                           inter_op_parallelism_threads=1))

        if args.num_env:

            def make_thunk(env_rank):
                # Compute this worker's seed eagerly. A lambda written directly
                # inside the comprehension would look up the loop variable only
                # when called/pickled, so every worker would get the last seed.
                env_seed = seed + env_rank if seed is not None else None
                return lambda: make_mujoco_env(env_id, env_seed,
                                               args.reward_scale)

            env = SubprocVecEnv(
                [make_thunk(i) for i in range(args.num_env)])
        else:
            env = DummyVecEnv(
                [lambda: make_mujoco_env(env_id, seed, args.reward_scale)])

        env = VecNormalize(env)

    elif env_type == 'atari':
        if alg == 'acer':
            # NOTE: an earlier variant also passed
            # wrapper_kwargs={'clip_rewards': False} here.
            env = make_atari_env(env_id, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env,
                                               frame_stack=True,
                                               scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            # Per-rank monitor directory; stays None when no log dir is set.
            env = bench.Monitor(
                env,
                logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        elif "Zelda" in env_id:
            # NOTE(review): machine-specific path; consider making this
            # configurable. Guarded so repeated calls don't keep growing
            # sys.path.
            gvgai_path = (
                "/home/jupyter/Notebooks/Chang/HardRLWithYoutube/nnrunner/a2c_gvgai"
            )
            if gvgai_path not in sys.path:
                sys.path.append(gvgai_path)
            import nnrunner.a2c_gvgai.env as gvgai_env
            frame_stack_size = 4
            print("run zelda")
            env = VecFrameStack(
                gvgai_env.make_gvgai_env(env_id,
                                         nenv,
                                         seed,
                                         level_selector=selector,
                                         experiment="PE",
                                         dataset="zelda"), frame_stack_size)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_atari_env(env_id, nenv, seed),
                                frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    elif env_type == 'classic_control':

        def make_env():
            e = gym.make(env_id)
            e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
            e.seed(seed)
            return e

        env = DummyVecEnv([make_env])

    else:
        raise ValueError('Unknown env_type {}'.format(env_type))

    print("build env")

    return env