Code Example #1
    def __init__(self, env_name, trace_path=None, fps=25):
        self.env_name = env_name
        self.env = gym.make(env_name)
        if not env_name.startswith('vgdl'):
            logger.debug('Assuming Atari env, enable AtariObservationWrapper')
            self.env = AtariObservationWrapper(self.env)
        if trace_path:
            self.env = TraceRecordingWrapper(self.env, trace_path)
        self.fps = fps
        self.cum_reward = 0
Code Example #2
    def __init__(self, env_name, trace_path=None, fps=15):
        self.env_name = env_name
        self.env = gym.make(env_name)
        if not env_name.startswith('vgdl'):
            logger.debug('Assuming Atari env, enable AtariObservationWrapper')
            from .wrappers import AtariObservationWrapper
            self.env = AtariObservationWrapper(self.env)
        if trace_path is not None and importlib.util.find_spec(
                'gym_recording') is not None:
            from gym_recording.wrappers import TraceRecordingWrapper
            self.env = TraceRecordingWrapper(self.env, trace_path)
        elif trace_path is not None:
            logger.warn(
                'trace_path provided but could not find the gym_recording package'
            )

        self.fps = fps
        self.cum_reward = 0
Code Example #3
import logging

import gym
from gym_recording.wrappers import TraceRecordingWrapper
from gym_recording.playback import scan_recorded_traces

logger = logging.getLogger(__name__)


def test_trace_recording():

    env = gym.make('CartPole-v0')
    env = TraceRecordingWrapper(env)
    recdir = env.directory
    agent = lambda ob: env.action_space.sample()

    for epi in range(10):
        ob = env.reset()
        for _ in range(100):
            assert env.observation_space.contains(ob)
            a = agent(ob)
            assert env.action_space.contains(a)
            (ob, _reward, done, _info) = env.step(a)
            if done: break
    env.close()

    counts = [0, 0]

    def handle_ep(observations, actions, rewards):
        counts[0] += 1
        counts[1] += observations.shape[0]
        logger.debug(
            'Observations.shape=%s, actions.shape=%s, rewards.shape=%s',
            observations.shape, actions.shape, rewards.shape)

    scan_recorded_traces(recdir, handle_ep)
    assert counts[0] == 10
    assert counts[1] > 100
Code Example #4
import os

import gym
from gym_recording.wrappers import TraceRecordingWrapper


def expert_agent():
    env = gym.make('CoGLEM1-virtual-v0')
    os.makedirs('./traces', exist_ok=True)
    env = TraceRecordingWrapper(env, directory='./traces/', buffer_batch_size=10)
    
    ITERATIONS = 10

    for x in range(ITERATIONS):
        obs = env.reset()
        done = False
        # calculate_expert_action is the expert policy defined elsewhere in the
        # originating project.
        while not done:
            env.render()
            action = calculate_expert_action(obs)
            print('x: ', x, 'Doing action: ', action, ' ', env.env._elapsed_steps, '\r')
            obs, reward, done, info = env.step(action)
            print('observations: ', obs, ' ', reward, ' ', done, '\r')

    env.close()
Code Example #5
import itertools
import logging
import time

import gym
from gym_recording.wrappers import TraceRecordingWrapper
# AtariObservationWrapper is the project's own wrapper, imported elsewhere as
# `from .wrappers import AtariObservationWrapper`.

logger = logging.getLogger(__name__)


class HumanController:
    def __init__(self, env_name, trace_path=None, fps=25):
        self.env_name = env_name
        self.env = gym.make(env_name)
        if not env_name.startswith('vgdl'):
            logger.debug('Assuming Atari env, enable AtariObservationWrapper')
            self.env = AtariObservationWrapper(self.env)
        if trace_path:
            self.env = TraceRecordingWrapper(self.env, trace_path)
        self.fps = fps
        self.cum_reward = 0

    def play(self, pause_on_finish=False):
        self.env.reset()

        for step_i in itertools.count():
            # self.controls (the key-press handler) is set up elsewhere in the
            # originating project and is not shown in this excerpt.
            # Only does something for VGDL because Atari's Pyglet is event-based
            self.controls.capture_key_presses()

            obs, reward, done, info = self.env.step(
                self.controls.current_action)
            if reward:
                logger.debug("reward %0.3f" % reward)

            self.cum_reward += reward
            window_open = self.env.render()

            self.after_step(self.env.unwrapped.game.time)

            if not window_open:
                logger.debug('Window closed')
                return False

            if done:
                logger.debug('===> Done!')
                if pause_on_finish:
                    self.controls.pause = True
                    pause_on_finish = False
                else:
                    break

            if self.controls.restart:
                logger.info('Requested restart')
                self.controls.restart = False
                break

            if self.controls.debug:
                self.controls.debug = False
                self.debug()
                continue

            while self.controls.pause:
                self.controls.capture_key_presses()
                self.env.render()
                time.sleep(1. / self.fps)

            time.sleep(1. / self.fps)

    def debug(self, *args, **kwargs):
        # Convenience debug breakpoint
        env = self.env.unwrapped
        game = env.game
        observer = env.observer
        obs = env.observer.get_observation()
        sprites = game.sprite_registry
        state = game.getGameState()
        all = dict(env=env,
                   game=game,
                   observer=observer,
                   obs=obs,
                   sprites=sprites,
                   state=state)
        print(all)

        import ipdb
        ipdb.set_trace()

    def after_step(self, step):
        pass
Code Example #6
    parser.add_argument('--switch-action-every', type=int, default=1)
    parser.add_argument('env_id',
                        nargs='?',
                        default='CrossCircle-MixedRand-v0')
    args = parser.parse_args()

    os.makedirs(args.directory, exist_ok=True)
    with open(os.path.join(args.directory, 'config.txt'), 'w') as f:
        f.write(pprint.pformat(args))

    env = gym.make(args.env_id)

    env = TraceRecordingWrapper(env,
                                directory=args.directory,
                                episode_filter=Filter(1),
                                frame_filter=Filter(1))

    env.seed(0)
    agent = RandomAgent(env.action_space, args.switch_action_every)

    reward = 0
    done = False

    for episode in range(args.n_episodes):
        ob = env.reset()
        # env.render()
        for step in range(args.n_steps):
            action = agent.act(ob, reward, done)
            ob, reward, done, info = env.step(action)
Code Example #7
#!/usr/bin/env python
from __future__ import print_function

import sys, gym, time
from gym_recording.wrappers import TraceRecordingWrapper
#
# Test yourself as a learning agent! Pass environment name as a command-line argument, for example:
#
# python keyboard_agent.py SpaceInvadersNoFrameskip-v4
#

env = gym.make('Freeway-v0')
env = TraceRecordingWrapper(env, 'record')
if not hasattr(env.action_space, 'n'):
    raise Exception('Keyboard agent only supports discrete action spaces')
ACTIONS = env.action_space.n
# Use previous control decision SKIP_CONTROL times; that's how you can test
# what skip is still usable.
SKIP_CONTROL = 0

human_agent_action = 0
human_wants_restart = False
human_sets_pause = False


def key_press(key, mod):
    global human_agent_action, human_wants_restart, human_sets_pause
    if key == 0xff0d: human_wants_restart = True
    if key == 32: human_sets_pause = not human_sets_pause
Code Example #8
start_time = time.time()

print("CUDA available: {}\n".format(torch.cuda.is_available()))
for difficulty in range(1, args.difficulty + 1):
    logs_by_difficulty = {
        "difficulty": difficulty,
        "actual_difficulty": [],
        "target_path_length": [],
        "num_frames_per_episode": [],
        "return_per_episode": [],
    }
    for i in range(args.episodes):
        episode_return = 0
        episode_num_frames = 0

        env = gym.make(args.env)
        env.random_seed = i
        env.seed(i)
        utils.seed(i)
        env.unwrapped.set_difficulty(difficulty)

        env = TraceRecordingWrapper(env, directory="storage/recordings")
        obs = env.reset()

        model_dir = utils.get_model_dir(args.env, args.model, args.seed)
        agent = utils.Agent(args.env, env.observation_space, model_dir, args.argmax, 1)
        done = False
        while not done:
            action = agent.get_action(obs)
            obs, reward, done, _ = env.step(action)
            agent.analyze_feedback(reward, done)

            episode_return += reward
            episode_num_frames += 1

            if done:
                logs_by_difficulty["actual_difficulty"].append(env.unwrapped.actual_difficulty)
Code Example #9
File: human.py  Project: ChaceHayhurst/py-vgdl
import importlib.util
import itertools
import logging
import time

import gym

logger = logging.getLogger(__name__)


class HumanController:
    def __init__(self, env_name, trace_path=None, fps=15):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.trace_path = trace_path
        if not env_name.startswith('vgdl'):
            logger.debug('Assuming Atari env, enable AtariObservationWrapper')
            from .wrappers import AtariObservationWrapper
            self.env = AtariObservationWrapper(self.env)
        if trace_path is not None and importlib.util.find_spec(
                'gym_recording') is not None:
            from gym_recording.wrappers import TraceRecordingWrapper
            self.env = TraceRecordingWrapper(self.env, trace_path)
        elif trace_path is not None:
            logger.warn(
                'trace_path provided but could not find the gym_recording package'
            )

        self.fps = fps
        self.cum_reward = 0

    def play(self, pause_on_finish=False, pause_on_start=False):
        self.env.reset()

        for step_i in itertools.count():
            if pause_on_start:
                self.controls.pause = True
                pause_on_start = False

            # Only does something for VGDL because Atari's Pyglet is event-based
            self.controls.capture_key_presses()

            obs, reward, done, info = self.env.step(
                self.controls.current_action)
            if reward:
                logger.debug("reward %0.3f" % reward)
            # Append a plain-text log of each transition next to the recorded trace;
            # guard against trace_path being None (its default).
            if self.trace_path is not None:
                with open(self.trace_path + "/game.txt", "a+") as f:
                    f.write("Observation" + str(obs) + ", Reward: " + str(reward) +
                            " Done?: " + str(done) + '\n')

            self.cum_reward += reward
            window_open = self.env.render()

            self.after_step(self.env.unwrapped.game.time)

            if not window_open:
                logger.debug('Window closed')
                return False

            if done:
                logger.debug('===> Done!')
                if pause_on_finish:
                    self.controls.pause = True
                    pause_on_finish = False
                else:
                    break

            if self.controls.restart:
                logger.info('Requested restart')
                self.controls.restart = False
                break

            if self.controls.debug:
                self.controls.debug = False
                self.debug()
                continue

            while self.controls.pause:
                self.controls.capture_key_presses()
                self.env.render()
                time.sleep(1. / self.fps)

            time.sleep(1. / self.fps)

    def debug(self, *args, **kwargs):
        # Convenience debug breakpoint
        env = self.env.unwrapped
        game = env.game
        observer = env.observer
        obs = env.observer.get_observation()
        sprites = game.sprite_registry
        state = game.get_game_state()
        all = dict(env=env,
                   game=game,
                   observer=observer,
                   obs=obs,
                   sprites=sprites,
                   state=state)
        print(all)

        import ipdb
        ipdb.set_trace()

    def after_step(self, step):
        pass
Code Example #10
def run(env_id, seed, noise_type, layer_norm, evaluation, custom_log_dir,
        **kwargs):
    # Configure things.
    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        logger.set_level(logger.DISABLED)

    train_recording_path = os.path.join(
        custom_log_dir, env_id, 'train',
        datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(train_recording_path)

    # Create envs.
    env = gym.make(env_id)
    env = TraceRecordingWrapper(env,
                                directory=train_recording_path,
                                buffer_batch_size=10)
    logger.info('TraceRecordingWrapper dir: {}'.format(env.directory))
    # env = bench.Monitor(env, os.path.join(train_recording_path, 'log'))

    if evaluation and rank == 0:
        eval_recording_path = os.path.join(
            custom_log_dir, env_id, 'eval',
            datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
        os.makedirs(eval_recording_path)

        eval_env = gym.make(env_id)
        eval_env = TraceRecordingWrapper(eval_env,
                                         directory=eval_recording_path,
                                         buffer_batch_size=10)
        logger.info('TraceRecordingWrapper eval dir: {}'.format(
            eval_env.directory))
        # eval_env = bench.Monitor(eval_env, os.path.join(logger.get_dir(), 'gym_eval'))
        # env = bench.Monitor(env, None)
    else:
        eval_env = None

    # Parse noise_type
    action_noise = None
    param_noise = None
    nb_actions = env.action_space.shape[-1]
    for current_noise_type in noise_type.split(','):
        current_noise_type = current_noise_type.strip()
        if current_noise_type == 'none':
            pass
        elif 'adaptive-param' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            param_noise = AdaptiveParamNoiseSpec(
                initial_stddev=float(stddev),
                desired_action_stddev=float(stddev))
        elif 'normal' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            action_noise = NormalActionNoise(mu=np.zeros(nb_actions),
                                             sigma=float(stddev) *
                                             np.ones(nb_actions))
        elif 'ou' in current_noise_type:
            _, stddev = current_noise_type.split('_')
            action_noise = OrnsteinUhlenbeckActionNoise(
                mu=np.zeros(nb_actions),
                sigma=float(stddev) * np.ones(nb_actions))
        else:
            raise RuntimeError(
                'unknown noise type "{}"'.format(current_noise_type))

    # Configure components.
    memory = Memory(limit=int(1e6),
                    action_shape=env.action_space.shape,
                    observation_shape=env.observation_space.shape)
    critic = Critic(layer_norm=layer_norm)
    actor = Actor(nb_actions, layer_norm=layer_norm)

    # Seed everything to make things reproducible.
    seed = seed + 1000000 * rank
    logger.info('DDPG: rank {}: seed={}, logdir={}'.format(
        rank, seed, logger.get_dir()))
    tf.reset_default_graph()
    set_global_seeds(seed)
    env.seed(seed)
    if eval_env is not None:
        eval_env.seed(seed)

    # Only rank 0 tracks total wall-clock time (logging for other ranks was
    # already disabled above).
    if rank == 0:
        start_time = time.time()
    training.train(env=env,
                   eval_env=eval_env,
                   param_noise=param_noise,
                   action_noise=action_noise,
                   actor=actor,
                   critic=critic,
                   memory=memory,
                   **kwargs)
    env.close()
    if eval_env is not None:
        eval_env.close()
    if rank == 0:
        logger.info('total runtime: {}s'.format(time.time() - start_time))