Ejemplo n.º 1
0
 def step(self, action):
     """Advance the environment by one timestep.

     ``action`` is unpacked as ``(inp_act, out_act, pred)``. ``inp_act`` is
     always forwarded to ``self._move`` (even on the step that ends the
     episode). When ``out_act == 1`` the symbol ``pred`` is compared with
     ``self.target`` at the write head and the head advances by one.

     Rewards: 1.0 for a correct write; -0.5 for a wrong write (which also
     ends the episode); -1.0 once ``self.time`` exceeds ``self.time_limit``
     (which ends the episode and replaces any write reward). The episode
     also ends when the write head moves past the last target position.

     Returns ``(obs, reward, done, {})`` where ``obs`` comes from
     ``self._get_obs()``.
     """
     assert self.action_space.contains(action)
     self.last_action = action
     move_act, write_act, symbol = action
     self.time += 1
     assert 0 <= self.write_head_position

     reward, done = 0.0, False
     if write_act == 1:
         head = self.write_head_position
         try:
             matched = symbol == self.target[head]
         except IndexError:
             # The head is already past the target (the episode is over);
             # warn and count the write as wrong instead of crashing.
             logger.warn(
                 "It looks like you're calling step() even though this "
                 "environment has already returned done=True. You should "
                 "always call reset() once you receive done=True. Any "
                 "further steps are undefined behaviour.")
             matched = False
         if matched:
             reward = 1.0
         else:
             # A single wrong symbol terminates the episode.
             reward, done = -0.5, True
         self.write_head_position = head + 1
         if self.write_head_position >= len(self.target):
             done = True

     self._move(move_act)
     if self.time > self.time_limit:
         # Exceeding the time budget overrides whatever the write earned.
         reward, done = -1.0, True

     observation = self._get_obs()
     self.last_reward = reward
     self.episode_total_reward += reward
     return observation, reward, done, {}
Ejemplo n.º 2
0
def should_skip_env_spec_for_tests(spec):
    """Return True when the tests for ``spec`` should be skipped.

    An env is skipped when it needs a dependency that is unavailable or when
    it is otherwise troublesome to run frequently:

    * mujoco / robotics envs, whenever the module-level ``skip_mujoco`` flag
      is truthy (pull request CI);
    * atari envs when ``atari_py`` cannot be imported;
    * box2d envs when ``Box2D`` cannot be imported;
    * Go and Hex envs, and every atari env other than Pong and Seaquest
      (these log a warning before being skipped).
    """
    entry = spec.entry_point

    # Skip mujoco tests for pull request CI
    if skip_mujoco and entry.startswith(('gym.envs.mujoco',
                                         'gym.envs.robotics:')):
        return True

    # Probe optional dependencies; the imported modules themselves are unused.
    try:
        import atari_py
    except ImportError:
        atari_missing = True
    else:
        atari_missing = False
    if atari_missing and entry.startswith('gym.envs.atari'):
        return True

    try:
        import Box2D
    except ImportError:
        box2d_missing = True
    else:
        box2d_missing = False
    if box2d_missing and entry.startswith('gym.envs.box2d'):
        return True

    troublesome = ('GoEnv' in entry or 'HexEnv' in entry
                   or (entry.startswith("gym.envs.atari")
                       and not spec.id.startswith(("Pong", "Seaquest"))))
    if not troublesome:
        return False
    logger.warn("Skipping tests for env {}".format(entry))
    return True
Ejemplo n.º 3
0
    def _encode_image_frame(self, frame):
        """Feed one image frame to the video encoder, creating it on first use.

        The encoder is built lazily from the first frame's ``shape`` and
        ``self.frames_per_sec``, and its ``version_info`` is stored under
        ``self.metadata['encoder_version']``. If the encoder rejects the frame
        with ``error.InvalidFrame``, a warning is logged and the recorder is
        marked broken instead of raising; on success ``self.empty`` is cleared.
        """
        if not self.encoder:
            self.encoder = ImageEncoder(self.path, frame.shape,
                                        self.frames_per_sec)
            encoder_info = self.encoder.version_info
            self.metadata['encoder_version'] = encoder_info

        try:
            self.encoder.capture_frame(frame)
        except error.InvalidFrame as exc:
            logger.warn(
                'Tried to pass invalid video frame, marking as broken: %s', exc)
            self.broken = True
            return
        # At least one frame has now been written successfully.
        self.empty = False
Ejemplo n.º 4
0
    def close(self, timeout=None, terminate=False):
        """Shut down the worker processes and release their pipes.

        Calling ``close`` on an already-closed instance is a no-op.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `close` times out. If `None`,
            the call to `close` never times out. If the call to `close` times
            out, then all processes are terminated.

        terminate : bool (default: `False`)
            If `True`, then the `close` operation is forced and all processes
            are terminated.
        """
        if self.closed:
            return

        if self.viewer is not None:
            self.viewer.close()

        # A forced close must not block on a pending asynchronous call.
        if terminate:
            timeout = 0
        try:
            if self._state != AsyncState.DEFAULT:
                pending = self._state.value
                logger.warn('Calling `close` while waiting for a pending '
                            'call to `{0}` to complete.'.format(pending))
                wait_for_pending = getattr(self, '{0}_wait'.format(pending))
                wait_for_pending(timeout)
        except mp.TimeoutError:
            # The pending call did not finish in time: escalate to terminate.
            terminate = True

        if terminate:
            for proc in self.processes:
                if proc.is_alive():
                    proc.terminate()
        else:
            # Graceful path: ask every open worker to close, then wait for
            # each acknowledgement.
            open_pipes = [conn for conn in self.parent_pipes
                          if conn is not None and not conn.closed]
            for conn in open_pipes:
                conn.send(('close', None))
            for conn in open_pipes:
                conn.recv()

        for conn in self.parent_pipes:
            if conn is not None:
                conn.close()
        for proc in self.processes:
            proc.join()

        self.closed = True
Ejemplo n.º 5
0
def patch_deprecated_methods(env):
    """
    Methods renamed from '_method' to 'method', render() no longer has 'close' parameter, close is a separate method.
    For backward compatibility, this makes it possible to work with unmodified environments.

    Concretely, ``env.reset``/``env.step``/``env.seed`` are aliased to
    ``env._reset``/``env._step``/``env._seed``. ``env.render(mode)`` forwards
    to ``env._render(mode, close=False)``, and ``env.close()`` forwards to
    ``env._render("human", close=True)``. A warning is logged only the first
    time this runs, tracked by the module-level ``warn_once`` flag.

    Raises AttributeError if ``env`` lacks ``_reset``, ``_step`` or ``_seed``.
    """
    global warn_once
    if warn_once:
        logger.warn(
            "Environment '%s' has deprecated methods '_step' and '_reset' rather than 'step' and 'reset'. Compatibility code invoked. Set _gym_disable_underscore_compat = True to disable this behavior."
            % str(type(env)))
        warn_once = False
    env.reset = env._reset
    env.step = env._step
    env.seed = env._seed

    def render(mode='human'):
        # Default to "human" (the same mode ``close`` uses below) so that a
        # bare ``env.render()`` call works instead of raising TypeError.
        return env._render(mode, close=False)

    def close():
        env._render("human", close=True)

    env.render = render
    env.close = close
Ejemplo n.º 6
0
    def capture_frame(self):
        """Render ``self.env`` once and hand the frame to the matching encoder.

        Does nothing unless ``self.functional`` is truthy. In ANSI mode the env
        is rendered with ``mode='ansi'`` and the frame goes to
        ``_encode_ansi_frame``; otherwise it is rendered with
        ``mode='rgb_array'`` and goes to ``_encode_image_frame``. The latest
        frame is kept in ``self.last_frame``. A ``None`` frame is ignored for
        async recorders; otherwise the recorder is marked broken after a
        warning instead of raising.
        """
        if not self.functional:
            return
        logger.debug('Capturing video frame: path=%s', self.path)

        frame = self.env.render(mode='ansi' if self.ansi_mode else 'rgb_array')

        if frame is not None:
            self.last_frame = frame
            encode = (self._encode_ansi_frame if self.ansi_mode
                      else self._encode_image_frame)
            encode(frame)
            return

        if self._async:
            return
        # Indicates a bug in the environment: don't want to raise
        # an error here.
        logger.warn(
            'Env returned None on render(). Disabling further rendering for video recorder by marking as disabled: path=%s metadata_path=%s',
            self.path, self.metadata_path)
        self.broken = True
Ejemplo n.º 7
0
def test_env_semantics(spec):
    """Compare a fresh rollout of ``spec`` with its recorded rollout hashes.

    Currently disabled: it only logs a warning and returns, because the
    existing hashes were generated in a bad way. The comparison logic after
    the early return is unreachable and is kept so the check can be
    re-enabled later. When enabled, it raises ``ValueError`` with every
    mismatch (observations, actions, rewards, dones) after logging each one.
    """
    logger.warn(
        "Skipping this test. Existing hashes were generated in a bad way")
    return
    # --- Unreachable until the early return above is removed. ---
    with open(ROLLOUT_FILE) as handle:
        recorded = json.load(handle)

    if spec.id not in recorded:
        # Nondeterministic envs are expected to have no recorded rollout.
        if not spec.nondeterministic:
            logger.warn(
                "Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs"
                .format(spec.id))
        return

    logger.info("Testing rollout for {} environment...".format(spec.id))

    observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(
        spec)

    expected = recorded[spec.id]
    comparisons = (
        ('observations', observations_now,
         'Observations not equal for {} -- expected {} but got {}'),
        ('actions', actions_now,
         'Actions not equal for {} -- expected {} but got {}'),
        ('rewards', rewards_now,
         'Rewards not equal for {} -- expected {} but got {}'),
        ('dones', dones_now,
         'Dones not equal for {} -- expected {} but got {}'),
    )
    errors = [
        template.format(spec.id, expected[key], actual)
        for key, actual, template in comparisons
        if expected[key] != actual
    ]
    if errors:
        for message in errors:
            logger.warn(message)
        raise ValueError(errors)
Ejemplo n.º 8
0
    def __init__(self,
                 env_fns,
                 observation_space=None,
                 action_space=None,
                 shared_memory=True,
                 copy=True,
                 context=None,
                 daemon=True,
                 worker=None):
        """Spawn one worker process per environment factory.

        Parameters
        ----------
        env_fns : sequence of callables
            Each callable builds one environment. It must support ``len()``
            and indexing, since ``env_fns[0]`` may be called here.

        observation_space, action_space : optional
            Spaces of a single environment. If either is `None`, a throwaway
            environment is built from ``env_fns[0]`` to read it, then closed.

        shared_memory : bool (default: `True`)
            If `True`, observations are placed in a buffer from
            ``create_shared_memory`` and read back through
            ``read_from_shared_memory``. Otherwise a regular array is
            allocated with ``create_empty_array(..., fn=np.zeros)``.

        copy : bool (default: `True`)
            Stored as ``self.copy``; it is not used in this method.

        context : optional
            Passed to ``mp.get_context``. On Python 2 (no ``get_context``) a
            warning is logged and the default ``multiprocessing`` module is used.

        daemon : bool (default: `True`)
            Value assigned to ``process.daemon`` for every worker.

        worker : callable, optional
            Overrides the worker target. If not given, ``_worker_shared_memory``
            is used when ``shared_memory`` is set, otherwise ``_worker``.
        """
        try:
            ctx = mp.get_context(context)
        except AttributeError:
            # `get_context` is missing (Python 2): fall back to the module itself.
            logger.warn('Context switching for `multiprocessing` is not '
                        'available in Python 2. Using the default context.')
            ctx = mp
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy

        if (observation_space is None) or (action_space is None):
            # Build one env just to read its spaces, then dispose of it.
            dummy_env = env_fns[0]()
            # NOTE(review): `or` falls back on truthiness, not `is None`. An
            # explicitly passed space that evaluates falsy (e.g. one defining
            # __len__ that returns 0) would be silently replaced. Confirm
            # whether `is None` checks were intended.
            observation_space = observation_space or dummy_env.observation_space
            action_space = action_space or dummy_env.action_space
            dummy_env.close()
            del dummy_env
        super(AsyncVectorEnv,
              self).__init__(num_envs=len(env_fns),
                             observation_space=observation_space,
                             action_space=action_space)

        # `single_observation_space` / `num_envs` are presumably set by the
        # base-class __init__ above. TODO confirm.
        if self.shared_memory:
            _obs_buffer = create_shared_memory(self.single_observation_space,
                                               n=self.num_envs,
                                               ctx=ctx)
            self.observations = read_from_shared_memory(
                _obs_buffer, self.single_observation_space, n=self.num_envs)
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros)

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        target = _worker_shared_memory if self.shared_memory else _worker
        # A caller-supplied worker takes precedence over the default targets.
        target = worker or target
        # NOTE(review): presumably strips MPI-related environment variables
        # while the children are started. Verify against clear_mpi_env_vars.
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                # The env factory is wrapped in CloudpickleWrapper before being
                # passed to the child. The parent end of the pipe is passed
                # too; presumably so the worker can close its inherited copy.
                # Confirm in the worker.
                process = ctx.Process(
                    target=target,
                    name='Worker<{0}>-{1}'.format(type(self).__name__, idx),
                    args=(idx, CloudpickleWrapper(env_fn), child_pipe,
                          parent_pipe, _obs_buffer, self.error_queue))

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # Once the child has started, the parent no longer needs its
                # handle to the child's end of the pipe.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_observation_spaces()