Beispiel #1
0
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from agents.deep_q_agent import DeepQAgent
from gym.wrappers import ResizeObservation
from max_frameskip_env import MaxFrameskipEnv
from reward_cache_env import RewardCacheEnv
from penalize_death_env import PenalizeDeathEnv

import tensorflow as tf
import keras

# Assemble the training environment. Wrappers are applied innermost-first:
# joypad action set, reward cache, 64x64 observation resize, death penalty,
# then max-frameskip. (The three project-local wrappers are described by name
# only; their implementations are not part of this script.)
env = JoypadSpace(gym_super_mario_bros.make('SuperMarioBros-v1'), SIMPLE_MOVEMENT)
env = ResizeObservation(RewardCacheEnv(env), (64, 64))
env = MaxFrameskipEnv(PenalizeDeathEnv(env))

# Create a TensorFlow 1.x session with explicit device counts and register it
# with Keras before the agent is constructed.
config = tf.ConfigProto(device_count={'GPU': 1, 'CPU': 4})
sess = tf.Session(config=config)
keras.backend.set_session(sess)

# Alternative environment kept for reference:
#env = gym.make('SpaceInvaders-v0')

# Deep Q agent with the dueling architecture and prioritized replay enabled;
# rendering is switched off.
agent = DeepQAgent(
    env=env,
    render_mode=None,
    dueling_network=True,
    prioritized_experience_replay=True,
)

agent.train()
Beispiel #2
0
    parser = argparse.ArgumentParser(description=None)
    # Question 1
    # Edit the variable `environment_used` to switch environments; it is the default for the env_id argument below.
    parser.add_argument('env_id', nargs='?', default=environment_used, help='Select the environment to run')
    args = parser.parse_args()
    
    # You can set the level to logger.DEBUG or logger.WARN if you
    # want to change the amount of output.
    logger.set_level(logger.INFO)

    env = gym.make(args.env_id)

    if environment_used != "CartPole-v1":
        env = GrayScaleObservation(env)
        env = ResizeObservation(env, resolution)
        env = FrameStack(env, 4)
        agent = VizdoomAgent(env.action_space, resolution , eta, test_mode, environment_used)
    else:
        agent = RandomAgent(env.action_space, env.observation_space, eta, test_mode, environment_used)

    # You provide the directory to write to (can be an existing
    # directory, including one with existing data -- all monitor files
    # will be namespaced). You can also dump to a tempdir if you'd
    # like: tempfile.mkdtemp().
    outdir = '/tmp/random-agent-results'
    env = wrappers.Monitor(env, directory=outdir, force=True)
    env.seed(0)

    mem = Memory()
Beispiel #3
0
def make_env_all_params(rank, args):
    """Initialize the environment and apply wrappers.

    The wrapper stack depends on ``args.env_kind`` ("atari", "mario",
    "retro_multi" or "roboarm"). Whatever the kind, the resulting environment
    is wrapped in a ``Monitor`` before being returned.

    Parameters
    ----------
    rank :
        Rank of the environment. Currently not used by this function.
    args :
        Hyperparameters for this run. ``env_kind`` is always read; which other
        attributes are read depends on the selected kind.

    Returns
    -------
    env
        Environment with its individual wrappers.

    Raises
    ------
    ValueError
        If ``args.env_kind`` is not one of the supported kinds.

    """
    # Heavy / optional dependencies are imported inside the branch that needs
    # them, so selecting one kind does not require the others' packages.
    if args.env_kind == "atari":
        from stable_baselines3.common.atari_wrappers import NoopResetEnv

        from nupic.embodied.disagreement.envs.wrappers import (
            AddRandomStateToInfo,
            ExtraTimeLimit,
            FrameStack,
            MaxAndSkipEnv,
            MontezumaInfoWrapper,
            ProcessFrame84,
            StickyActionEnv,
        )
        env = gym.make(args.env)
        # Frame skipping is added below by MaxAndSkipEnv, so the base
        # environment must be a "NoFrameskip" variant.
        assert "NoFrameskip" in env.spec.id, (
            f"atari env must be a NoFrameskip variant, got {env.spec.id!r}"
        )
        if args.stickyAtari:
            # NOTE(review): the factor 4 presumably matches skip=4 of
            # MaxAndSkipEnv below -- confirm.
            env._max_episode_steps = args.max_episode_steps * 4
            env = StickyActionEnv(env)
        else:
            env = NoopResetEnv(env, noop_max=args.noop_max)
        env = MaxAndSkipEnv(env, skip=4)
        env = ProcessFrame84(env, crop=False)
        env = FrameStack(env, 4)
        if not args.stickyAtari:
            # The sticky variant already had its step limit set above.
            env = ExtraTimeLimit(env, args.max_episode_steps)
        if "Montezuma" in args.env:
            env = MontezumaInfoWrapper(env)
        env = AddRandomStateToInfo(env)
    elif args.env_kind == "mario":
        from nupic.embodied.disagreement.envs.wrappers import make_mario_env
        env = make_mario_env()
    elif args.env_kind == "retro_multi":
        from nupic.embodied.disagreement.envs.wrappers import make_multi_pong
        env = make_multi_pong()
    elif args.env_kind == "roboarm":
        from real_robots.envs import REALRobotEnv

        from nupic.embodied.disagreement.envs.wrappers import CartesianControlDiscrete
        env = REALRobotEnv(objects=3, action_type="cartesian")
        env = CartesianControlDiscrete(
            env,
            crop_obs=args.crop_obs,
            repeat=args.act_repeat,
            touch_reward=args.touch_reward,
            random_force=args.random_force,
        )
        # A non-positive resize_obs leaves the observations at native size.
        if args.resize_obs > 0:
            env = ResizeObservation(env, args.resize_obs)
    else:
        # Without this branch `env` would be unbound and the function would
        # fail further down with an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown env_kind {args.env_kind!r}; expected one of "
            "'atari', 'mario', 'retro_multi', 'roboarm'"
        )

    print("adding monitor")
    env = Monitor(env, filename=None)
    return env