コード例 #1
0
ファイル: make_env.py プロジェクト: duynguyen158/wann-nlp
def make_env(env_name, encoder, max_features, seed=-1, render_mode=False):
    """Create a text/speech classification environment.

    Args:
        env_name: Environment identifier. Must start with ``"Classify"`` and
            end with ``"spam"``, ``"imdb"``, ``"speech_mnist"`` or
            ``"speech_yesno"``.
        encoder: Forwarded to the ``spam``/``imdb`` dataset loaders; not
            used by the speech datasets.
        max_features: Forwarded to the ``spam``/``imdb`` dataset loaders;
            not used by the speech datasets.
        seed: When non-negative, passed to ``env.seed`` after construction.
        render_mode: Unused; accepted so callers can pass it uniformly.

    Returns:
        A ``ClassifyEnv`` built from the selected dataset.

    Raises:
        ValueError: If ``env_name`` does not name a supported dataset.
    """
    # -- Classification ------------------------------------------------ -- #
    if not env_name.startswith("Classify"):
        raise ValueError("Unknown environment: %s" % env_name)

    # Resolve the dataset and the third ClassifyEnv argument (presumably the
    # batch size -- confirm against domain.classify_gym) before building the
    # env, so an unsupported suffix fails with a clear error rather than an
    # unbound ``env``.
    if env_name.endswith("spam"):
        from domain.classify_gym import spam
        train_set, target = spam(encoder, max_features)
        batch = 75
    elif env_name.endswith("imdb"):
        from domain.classify_gym import imdb
        train_set, target = imdb(encoder, max_features)
        batch = 400
    elif env_name.endswith("speech_mnist"):
        from domain.classify_gym import speech_mnist
        train_set, target = speech_mnist()
        batch = 512
    elif env_name.endswith("speech_yesno"):
        from domain.classify_gym import speech_yesno
        train_set, target = speech_yesno()
        batch = 512
    else:
        raise ValueError("Unknown classification dataset: %s" % env_name)

    from domain.classify_gym import ClassifyEnv
    env = ClassifyEnv(train_set, target, batch)

    if seed >= 0:
        # Seed the constructed environment itself (assumes ClassifyEnv
        # exposes gym's ``seed`` method).
        env.seed(seed)

    return env
コード例 #2
0
def make_env(env_name, seed=-1, render_mode=False):
    """Create the environment named by ``env_name``.

    Custom environments from the ``domain`` package are selected by name
    prefix; any other name is passed to ``gym.make``.

    Args:
        env_name: Environment identifier. Recognized prefixes are
            ``"BipedalWalker"`` (plus ``"Hardcore"``/``"Medium"`` variants),
            ``"VAERacing"``, ``"Classify"`` (suffix ``"digits"``,
            ``"mnist256"`` or ``"fashionmnist"``) and ``"CartPoleSwingUp"``.
        seed: When non-negative, passed to ``env.seed`` after construction.
        render_mode: Unused; accepted so callers can pass it uniformly.

    Returns:
        The constructed environment.

    Raises:
        ValueError: If ``env_name`` starts with ``"Classify"`` but has no
            supported dataset suffix.
    """
    if "Bullet" in env_name:
        # The bound names are unused here; presumably these imports are kept
        # for their side effect of registering the Bullet envs with gym.
        import pybullet as p
        import pybullet_envs
        import pybullet_envs.bullet.kukaGymEnv as kukaGymEnv

    # -- Bipedal Walker ------------------------------------------------ -- #
    if env_name.startswith("BipedalWalker"):
        if env_name.startswith("BipedalWalkerHardcore"):
            import Box2D
            from domain.bipedal_walker import BipedalWalkerHardcore
            env = BipedalWalkerHardcore()
        else:
            from domain.bipedal_walker import BipedalWalker
            env = BipedalWalker()
            if env_name.startswith("BipedalWalkerMedium"):
                env.accel = 3

    # -- VAE Racing ---------------------------------------------------- -- #
    elif env_name.startswith("VAERacing"):
        from domain.vae_racing import VAERacing
        env = VAERacing()

    # -- Classification ------------------------------------------------ -- #
    elif env_name.startswith("Classify"):
        # Pick the dataset loader first so an unknown suffix raises a clear
        # error instead of leaving the training data unbound.
        if env_name.endswith("digits"):
            from domain.classify_gym import digit_raw as load_dataset
        elif env_name.endswith("mnist256"):
            from domain.classify_gym import mnist_256 as load_dataset
        elif env_name.endswith("fashionmnist"):
            from domain.classify_gym import fashion_mnist as load_dataset
        else:
            raise ValueError("Unknown classification dataset: %s" % env_name)
        from domain.classify_gym import ClassifyEnv
        train_set, target = load_dataset()
        env = ClassifyEnv(train_set, target)

    # -- Cart Pole Swing up -------------------------------------------- -- #
    elif env_name.startswith("CartPoleSwingUp"):
        from domain.cartpole_swingup import CartPoleSwingUpEnv
        env = CartPoleSwingUpEnv()

    # -- Other  -------------------------------------------------------- -- #
    else:
        env = gym.make(env_name)

    if seed >= 0:
        # Seed the constructed environment itself (assumes it exposes gym's
        # ``seed`` method).
        env.seed(seed)

    return env
コード例 #3
0
def make_env(env_name, seed=-1, render_mode=False):
    """Create the environment named by ``env_name``.

    Custom environments from the ``domain`` package are selected by name
    prefix; a few classic-control names and any unrecognized name are passed
    to ``gym.make``.

    Args:
        env_name: Environment identifier. Recognized prefixes are
            ``"BipedalWalker"`` (plus ``"Hardcore"``/``"Medium"`` variants),
            ``"VAERacing"``, ``"Classify"`` (suffix ``"digits"`` or
            ``"mnist256"``) and ``"CartPoleSwingUp"`` (``"_Hard"`` variant).
        seed: When non-negative, passed to ``env.seed`` after construction.
        render_mode: Unused; accepted so callers can pass it uniformly.

    Returns:
        The constructed environment.

    Raises:
        ValueError: If ``env_name`` starts with ``"Classify"`` but has no
            supported dataset suffix.
    """
    # -- Bullet Environments ------------------------------------------- -- #
    if "Bullet" in env_name:
        # The bound names are unused here; presumably these imports are kept
        # for their side effect of registering the Bullet envs with gym.
        import pybullet as p  # pip install pybullet
        import pybullet_envs
        import pybullet_envs.bullet.kukaGymEnv as kukaGymEnv

    # -- Bipedal Walker ------------------------------------------------ -- #
    if env_name.startswith("BipedalWalker"):
        if env_name.startswith("BipedalWalkerHardcore"):
            import Box2D
            from domain.bipedal_walker import BipedalWalkerHardcore
            env = BipedalWalkerHardcore()
        else:
            from domain.bipedal_walker import BipedalWalker
            env = BipedalWalker()
            if env_name.startswith("BipedalWalkerMedium"):
                env.accel = 3

    # -- Custom control tasks for pivector work ------------------------ -- #
    elif env_name in ("Pendulum-v0", "CartPole-v1", "Acrobot-v1",
                      "LunarLander-v2"):
        env = gym.make(env_name)

    # -- VAE Racing ---------------------------------------------------- -- #
    elif env_name.startswith("VAERacing"):
        from domain.vae_racing import VAERacing
        env = VAERacing()

    # -- Classification ------------------------------------------------ -- #
    elif env_name.startswith("Classify"):
        # Pick the dataset loader first so an unknown suffix raises a clear
        # error instead of leaving the training data unbound.
        if env_name.endswith("digits"):
            from domain.classify_gym import digit_raw as load_dataset
        elif env_name.endswith("mnist256"):
            from domain.classify_gym import mnist_256 as load_dataset
        else:
            raise ValueError("Unknown classification dataset: %s" % env_name)
        from domain.classify_gym import ClassifyEnv
        train_set, target = load_dataset()
        env = ClassifyEnv(train_set, target)

    # -- Cart Pole Swing up -------------------------------------------- -- #
    elif env_name.startswith("CartPoleSwingUp"):
        from domain.cartpole_swingup import CartPoleSwingUpEnv
        env = CartPoleSwingUpEnv()
        if env_name.startswith("CartPoleSwingUp_Hard"):
            # Presumably a finer timestep and a shorter episode limit.
            env.dt = 0.01
            env.t_limit = 200

    # -- Other  -------------------------------------------------------- -- #
    else:
        env = gym.make(env_name)

    if seed >= 0:
        # Seed the constructed environment itself (assumes it exposes gym's
        # ``seed`` method).
        env.seed(seed)

    return env
コード例 #4
0
def make_env(env_name, seed=-1, render_mode=False):
    """Create the environment named by ``env_name``.

    Custom environments from the ``domain`` package are selected by name
    prefix; any other name is passed to ``gym.make``. CartPoleSwingUp envs
    are wrapped in ``Monitor`` so every episode is recorded to ``./video/``.

    Args:
        env_name: Environment identifier. Recognized prefixes are
            ``"BipedalWalker"`` (plus ``"Hardcore"``/``"Medium"`` variants),
            ``"VAERacing"``, ``"Classify"`` (suffix ``"digits"``,
            ``"mnist784"`` or ``"mnist256"``) and ``"CartPoleSwingUp"``
            (``"_Hard"`` variant).
        seed: When non-negative, passed to ``env.seed`` after construction.
        render_mode: Unused; accepted so callers can pass it uniformly.

    Returns:
        The constructed (possibly Monitor-wrapped) environment.

    Raises:
        ValueError: If ``env_name`` starts with ``"Classify"`` but has no
            supported dataset suffix.
    """
    # -- Bipedal Walker ------------------------------------------------ -- #
    if env_name.startswith("BipedalWalker"):
        if env_name.startswith("BipedalWalkerHardcore"):
            import Box2D
            from domain.bipedal_walker import BipedalWalkerHardcore
            env = BipedalWalkerHardcore()
        else:
            from domain.bipedal_walker import BipedalWalker
            env = BipedalWalker()
            if env_name.startswith("BipedalWalkerMedium"):
                env.accel = 3

    # -- VAE Racing ---------------------------------------------------- -- #
    elif env_name.startswith("VAERacing"):
        from domain.vae_racing import VAERacing
        env = VAERacing()

    # -- Classification ------------------------------------------------ -- #
    elif env_name.startswith("Classify"):
        # Pick the dataset loader first so an unknown suffix raises a clear
        # error instead of leaving the training data unbound.
        if env_name.endswith("digits"):
            from domain.classify_gym import digit_raw as load_dataset
        elif env_name.endswith("mnist784"):
            from domain.classify_gym import mnist_784 as load_dataset
        elif env_name.endswith("mnist256"):
            from domain.classify_gym import mnist_256 as load_dataset
        else:
            raise ValueError("Unknown classification dataset: %s" % env_name)
        from domain.classify_gym import ClassifyEnv
        train_set, target = load_dataset()
        env = ClassifyEnv(train_set, target)

    # -- Cart Pole Swing up -------------------------------------------- -- #
    elif env_name.startswith("CartPoleSwingUp"):
        from domain.cartpole_swingup import CartPoleSwingUpEnv
        env = CartPoleSwingUpEnv()
        if env_name.startswith("CartPoleSwingUp_Hard"):
            # Presumably a finer timestep and a shorter episode limit.
            env.dt = 0.01
            env.t_limit = 200
        # Wrap only after configuring: a gym Wrapper forwards attribute reads
        # to the inner env but not assignments, so setting dt/t_limit on the
        # Monitor would never reach CartPoleSwingUpEnv.
        env = Monitor(env,
                      './video/',
                      video_callable=lambda episode_id: True,
                      force=True)

    # -- Other  -------------------------------------------------------- -- #
    else:
        env = gym.make(env_name)

    if seed >= 0:
        # Seed the constructed environment itself (assumes it exposes gym's
        # ``seed`` method; Monitor forwards it to the inner env).
        env.seed(seed)

    return env
コード例 #5
0
def make_env(env_name, seed=-1, render_mode=False):
    """Create the environment named by ``env_name``.

    Custom environments from the ``domain`` package are selected by name
    prefix; any other name is passed to ``gym.make``.

    Args:
        env_name: Environment identifier. Recognized prefixes are
            ``"BipedalWalker"`` (plus ``"Hardcore"``/``"Medium"`` variants),
            ``"VAERacing"``, ``"Classify"`` (suffix ``"digits"``,
            ``"mnist784"``, ``"mnist256train"`` or ``"mnist256test"``) and
            ``"CartPoleSwingUp"`` (``"_Hard"`` variant).
        seed: When non-negative, passed to ``env.seed`` after construction.
        render_mode: Unused; accepted so callers can pass it uniformly.

    Returns:
        The constructed environment.

    Raises:
        ValueError: If ``env_name`` starts with ``"Classify"`` but has no
            supported dataset suffix.
    """
    # -- Bipedal Walker ------------------------------------------------ -- #
    if env_name.startswith("BipedalWalker"):
        if env_name.startswith("BipedalWalkerHardcore"):
            import Box2D
            from domain.bipedal_walker import BipedalWalkerHardcore
            env = BipedalWalkerHardcore()
        else:
            from domain.bipedal_walker import BipedalWalker
            env = BipedalWalker()
            if env_name.startswith("BipedalWalkerMedium"):
                env.accel = 3

    # -- VAE Racing ---------------------------------------------------- -- #
    elif env_name.startswith("VAERacing"):
        from domain.vae_racing import VAERacing
        env = VAERacing()

    # -- Classification ------------------------------------------------ -- #
    elif env_name.startswith("Classify"):
        # Every supported suffix now constructs an env; previously the
        # digits/mnist784 branches loaded data but never built ``env``.
        env_kwargs = {}
        if env_name.endswith("digits"):
            from domain.classify_gym import digit_raw as load_dataset
        elif env_name.endswith("mnist784"):
            from domain.classify_gym import mnist_784 as load_dataset
        elif env_name.endswith("mnist256train"):
            from domain.classify_gym import mnist_256 as load_dataset
            env_kwargs = {"batch_size": 1000, "accuracy_mode": False}
        elif env_name.endswith("mnist256test"):
            from domain.classify_gym import mnist_256_test as load_dataset
            env_kwargs = {"batch_size": 1000, "accuracy_mode": False}
        else:
            raise ValueError("Unknown classification dataset: %s" % env_name)
        from domain.classify_gym import ClassifyEnv
        images, labels = load_dataset()
        env = ClassifyEnv(images, labels, **env_kwargs)

    # -- Cart Pole Swing up -------------------------------------------- -- #
    elif env_name.startswith("CartPoleSwingUp"):
        from domain.cartpole_swingup import CartPoleSwingUpEnv
        env = CartPoleSwingUpEnv()
        if env_name.startswith("CartPoleSwingUp_Hard"):
            # Presumably a finer timestep and a shorter episode limit.
            env.dt = 0.01
            env.t_limit = 200

    # -- Other  -------------------------------------------------------- -- #
    else:
        env = gym.make(env_name)

    if seed >= 0:
        # Seed the constructed environment itself (assumes it exposes gym's
        # ``seed`` method).
        env.seed(seed)

    return env
コード例 #6
0
def make_env(env_name, seed=-1, render_mode=False):
    """Create the environment named by ``env_name``.

    Custom environments from the ``domain`` package are selected by name
    prefix; any other name is passed to ``gym.make``.

    Args:
        env_name: Environment identifier. Recognized prefixes are
            ``"BipedalWalker"`` (plus ``"Hardcore"``/``"Medium"`` variants),
            ``"VAERacing"``, ``"Classify"`` (suffix ``"digits"``,
            ``"mnist784"`` or ``"mnist256"``), ``"CartPoleSwingUp"``
            (``"alt"``/``"simple"`` suffixes, ``"_Hard"`` variant), ``"SDF"``
            and ``"gates"``.
        seed: When non-negative, passed to ``env.seed`` after construction.
        render_mode: Unused; accepted so callers can pass it uniformly.

    Returns:
        The constructed environment.

    Raises:
        ValueError: If ``env_name`` starts with ``"Classify"`` but has no
            supported dataset suffix.
    """
    # -- Bipedal Walker ------------------------------------------------ -- #
    if env_name.startswith("BipedalWalker"):
        if env_name.startswith("BipedalWalkerHardcore"):
            import Box2D
            from domain.bipedal_walker import BipedalWalkerHardcore
            env = BipedalWalkerHardcore()
        else:
            from domain.bipedal_walker import BipedalWalker
            env = BipedalWalker()
            if env_name.startswith("BipedalWalkerMedium"):
                env.accel = 3

    # -- VAE Racing ---------------------------------------------------- -- #
    elif env_name.startswith("VAERacing"):
        from domain.vae_racing import VAERacing
        env = VAERacing()

    # -- Classification ------------------------------------------------ -- #
    elif env_name.startswith("Classify"):
        # Pick the dataset loader first so an unknown suffix raises a clear
        # error instead of leaving the training data unbound.
        if env_name.endswith("digits"):
            from domain.classify_gym import digit_raw as load_dataset
        elif env_name.endswith("mnist784"):
            from domain.classify_gym import mnist_784 as load_dataset
        elif env_name.endswith("mnist256"):
            from domain.classify_gym import mnist_256 as load_dataset
        else:
            raise ValueError("Unknown classification dataset: %s" % env_name)
        from domain.classify_gym import ClassifyEnv
        train_set, target = load_dataset()
        env = ClassifyEnv(train_set, target)

    # -- Cart Pole Swing up -------------------------------------------- -- #
    elif env_name.startswith("CartPoleSwingUp"):
        if env_name.endswith("alt"):
            from domain.cartpole_swingup_altered import CartPoleSwingUpEnv
            env = CartPoleSwingUpEnv()
        elif env_name.endswith("simple"):
            from domain.cartpole_swingup_simplified import CartPoleSwingUpSimpleEnv
            env = CartPoleSwingUpSimpleEnv()
        else:
            from domain.cartpole_swingup import CartPoleSwingUpEnv
            env = CartPoleSwingUpEnv()

        if env_name.startswith("CartPoleSwingUp_Hard"):
            # Presumably a finer timestep and a shorter episode limit.
            env.dt = 0.01
            env.t_limit = 200

    # -- Regression ---------------------------------------------------- -- #
    elif env_name.startswith("SDF"):
        from domain.regression_gym import RegressionEnv
        env = RegressionEnv()

    # -- Logic gates --------------------------------------------------- -- #
    elif env_name.startswith("gates"):
        from domain.gate_gym import LogicGateEnv
        # Assumes names of the form "gates_<GATE>-v1"; the slice strips that
        # fixed prefix and suffix to recover <GATE>.
        gate = env_name[len("gates_"):-len("-v1")]
        env = LogicGateEnv(gate)

    # -- Other  -------------------------------------------------------- -- #
    else:
        env = gym.make(env_name)

    if seed >= 0:
        # Seed the constructed environment itself (assumes it exposes gym's
        # ``seed`` method).
        env.seed(seed)

    return env