Example No. 1
def run(agent: base.Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        verbose: bool = False) -> None:
  """Runs an agent on an environment.

  Note that for bsuite environments, logging is handled internally.

  Args:
    agent: The agent to train and evaluate.
    environment: The environment to train on.
    num_episodes: Number of episodes to train for.
    verbose: Whether to also log to terminal.
  """

  if verbose:
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types

  for _ in range(num_episodes):
    # Run an episode.
    timestep = environment.reset()
    while not timestep.last():
      # Generate an action from the agent's policy.
      action = agent.select_action(timestep)

      # Step the environment.
      new_timestep = environment.step(action)

      # Tell the agent about what just happened.
      agent.update(timestep, action, new_timestep)

      # Book-keeping.
      timestep = new_timestep
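
A minimal usage sketch (illustrative, not part of the original source): it assumes bsuite is installed and substitutes a trivial random agent for a learned one.

import bsuite
import numpy as np


class RandomAgent:
  """Minimal stand-in for the base.Agent interface expected by `run`."""

  def __init__(self, num_actions: int):
    self._num_actions = num_actions

  def select_action(self, timestep):
    del timestep  # A random policy ignores the observation.
    return np.random.randint(self._num_actions)

  def update(self, timestep, action, new_timestep):
    pass  # A learning agent would update its parameters here.


environment = bsuite.load_from_id('catch/0')
agent = RandomAgent(environment.action_spec().num_values)
run(agent, environment, num_episodes=10)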
Example No. 2
def run(
    agent: base.Agent,
    env: dm_env.Environment,
    num_episodes: int,
    eval_mode: bool = False,
) -> base.Agent:
    wandb.init(project="dqn")
    logging.info(
        "Starting {} agent {} on environment {}.\nThe scheduled number of episode is {}"
        .format("evaluating" if eval_mode else "training", agent, env,
                num_episodes))
    for episode in range(num_episodes):
        print(
            "Starting episode number {}/{}\t\t\t".format(
                episode, num_episodes - 1),
            end="\r",
        )
        wandb.log({"Episode": episode})
        # initialise environment
        timestep = env.reset()
        while not timestep.last():
            # policy
            action = agent.select_action(timestep)
            # step environment
            new_timestep = env.step(tuple(action))
            wandb.log({"Reward": new_timestep.reward})
            # update
            if not eval_mode:
                loss = agent.update(timestep, action, new_timestep)
                if loss is not None:
                    wandb.log({"Bellman MSE": float(loss)})
                wandb.log({"Iteration": agent.iteration})
            # prepare next
            timestep = new_timestep
    return agent
Example No. 3
def run_episode(agent: Agent,
                env: Environment,
                action_repeat: int = 1,
                update: bool = False):
    start_time = time()
    episode_steps = 0
    episode_return = 0
    timestep = env.reset()
    agent.observe_first(timestep)
    while not timestep.last():
        action = agent.select_action(timestep.observation)
        for _ in range(action_repeat):
            timestep = env.step(action)
            agent.observe(action, next_timestep=timestep)
            episode_steps += 1
            episode_return += timestep.reward
            if update:
                agent.update()
            if timestep.last():
                break

    steps_per_second = episode_steps / (time() - start_time)
    result = {
        'episode_length': episode_steps,
        'episode_return': episode_return,
        'steps_per_second': steps_per_second,
    }
    return result
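
Illustrative usage (assumes `agent` and `env` were already built against the Agent/Environment interfaces above): collect per-episode results and average the return.

results = [run_episode(agent, env, action_repeat=4, update=True)
           for _ in range(10)]
mean_return = sum(r['episode_return'] for r in results) / len(results)
print('Mean return over {} episodes: {:.2f}'.format(len(results), mean_return))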
Example No. 4
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the
      observation, action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReplaySample
    object indefinitely.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = (observation, action, reward, discount, observation)

    key = np.array(0, np.uint64)
    probability = np.array(1.0, np.float64)
    table_size = np.array(1, np.int64)
    priority = np.array(1.0, np.float64)
    info = reverb.SampleInfo(key=key,
                             probability=probability,
                             table_size=table_size,
                             priority=priority)
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
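
A short usage sketch (illustrative): the dataset repeats a single fake ReplaySample, so it can be batched and iterated indefinitely, e.g. to smoke-test a learner's input pipeline.

dataset = transition_dataset(environment).batch(2)
sample = next(iter(dataset))  # reverb.ReplaySample with a leading batch dimension of 2.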
Example No. 5
def run(agent: Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        results_dir: str = 'res/default.pkl') -> None:
    '''
    Runs an agent on an environment.

    Args:
        agent: The agent to train and evaluate.
        environment: The environment to train on.
        num_episodes: Number of episodes to train for.
        results_dir: Path where the agent's network is saved after each episode.
    '''

    for episode in range(num_episodes):
        # Run an episode.
        timestep = environment.reset()

        while not timestep.last():
            action = agent.select_action(timestep)
            print(action)
            new_timestep = environment.step(action)

            # Pass the (s, a, r, s') info to the agent.
            agent.update(timestep, action, new_timestep)

            # update timestep
            timestep = new_timestep

        if (episode + 1) % 100 == 0:
            print("Episode %d completed." % (episode + 1))

        # Save the agent's network after every episode.
        torch.save(agent._network, results_dir)
Example No. 6
def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec:
  """Returns an `EnvironmentSpec` describing values used by an environment."""
  return EnvironmentSpec(
      observations=environment.observation_spec(),
      actions=environment.action_spec(),
      rewards=environment.reward_spec(),
      discounts=environment.discount_spec())
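
Illustrative usage (assuming a discrete-action environment with a single array observation): the returned spec bundles the four sub-specs so downstream code can size networks without touching the environment again.

spec = make_environment_spec(environment)
num_actions = spec.actions.num_values  # Discrete action specs expose num_values.
obs_shape = spec.observations.shape    # Assumes a single (non-dict) observation spec.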
Example No. 7
def run_loop(
    agent: Agent,
    environment: dm_env.Environment,
    max_steps_per_episode: int = 0,
    yield_before_reset: bool = False,
) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent,
                    Optional[Action]]]:
    """Repeatedly alternates step calls on environment and agent.

    At time `t`, `t + 1` environment timesteps and `t + 1` agent steps have been
    seen in the current episode. `t` resets to `0` for the next episode.

    Args:
      agent: Agent to be run, has methods `step(timestep)` and `reset()`.
      environment: Environment to run, has methods `step(action)` and `reset()`.
      max_steps_per_episode: If positive, when time t reaches this value within an
        episode, the episode is truncated.
      yield_before_reset: Whether to additionally yield `(environment, None,
        agent, None)` before the agent and environment are reset at the start of
        each episode.

    Yields:
      Tuple `(environment, timestep_t, agent, a_t)` where
      `a_t = agent.step(timestep_t)`.
    """
    while True:  # For each episode.
        if yield_before_reset:
            yield environment, None, agent, None,

        t = 0
        agent.reset()
        timestep_t = environment.reset()  # timestep_0.

        while True:  # For each step in the current episode.
            a_t = agent.step(timestep_t)
            yield environment, timestep_t, agent, a_t

            # Update t after one environment step and agent step and relabel.
            t += 1
            a_tm1 = a_t
            timestep_t = environment.step(a_tm1)

            if max_steps_per_episode > 0 and t >= max_steps_per_episode:
                assert t == max_steps_per_episode
                timestep_t = timestep_t._replace(
                    step_type=dm_env.StepType.LAST)

            if timestep_t.last():
                unused_a_t = agent.step(
                    timestep_t)  # Extra agent step, action ignored.
                yield environment, timestep_t, agent, None
                break
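
Because `run_loop` is an endless generator, callers typically bound it externally, for example with `itertools.islice` (illustrative sketch, assuming `agent` and `environment` already exist):

import itertools

for _, timestep_t, _, a_t in itertools.islice(
        run_loop(agent, environment, max_steps_per_episode=500), 10_000):
    if a_t is None:
        continue  # End-of-episode yield; no action was taken.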
Example No. 8
    def __init__(self,
                 environment: dm_env.Environment,
                 *,
                 frame_rate: Optional[int] = None,
                 camera_id: Optional[int] = 0,
                 height: int = 240,
                 width: int = 320,
                 playback_speed: float = 1.,
                 **kwargs):

        # Check that we have a mujoco environment (or a wrapper thereof).
        if not hasattr(environment, '_physics'):
            raise ValueError(
                'MujocoVideoWrapper expects an environment which '
                'exposes a _physics attribute corresponding to a MuJoCo '
                'physics engine')

        # Compute frame rate if not set.
        if frame_rate is None:
            frame_rate = int(
                round(playback_speed / environment.control_timestep()))

        super().__init__(environment, frame_rate=frame_rate, **kwargs)
        self._camera_id = camera_id
        self._height = height
        self._width = width
Example No. 9
    def __init__(self,
                 environment: dm_env.Environment,
                 name_filter: Optional[Sequence[str]] = None):
        """Initializes a new ConcatObservationWrapper.

    Args:
      environment: Environment to wrap.
      name_filter: Sequence of observation names to keep. None keeps them all.
    """
        super().__init__(environment)
        observation_spec = environment.observation_spec()
        if name_filter is None:
            name_filter = list(observation_spec.keys())
        self._obs_names = [
            x for x in name_filter if x in observation_spec.keys()
        ]

        dummy_obs = _zeros_like(observation_spec)
        dummy_obs = self._convert_observation(dummy_obs)
        self._observation_spec = dm_env.specs.BoundedArray(
            shape=dummy_obs.shape,
            dtype=dummy_obs.dtype,
            minimum=-np.inf,
            maximum=np.inf,
            name='state')
Example No. 10
    def __init__(
            self,
            env: Environment,
            alpha: float,
            lamda: float,
            n_features: int,
            logger: Logger = NullLogger(),
    ):
        super().__init__()
        self.alpha = alpha
        self.lamda = lamda
        self.w = jnp.zeros(n_features + 1)
        self.z = jnp.zeros(n_features + 1)
        self.logger = logger
        self.policy = EGreedy(env.action_spec, 0.1)

        # temporary value estimate for the starting state
        self._actions = jnp.arange(env.action_spec().num_values)
        self._q_old = jnp.zeros(env.action_spec().num_values)
Example No. 11
 def _make_ma_environment_spec(
         self,
         environment: dm_env.Environment) -> Dict[str, EnvironmentSpec]:
     """Returns an `EnvironmentSpec` describing values used by
     an environment for each agent."""
     specs = {}
     observation_specs = environment.observation_spec()
     action_specs = environment.action_spec()
     reward_specs = environment.reward_spec()
     discount_specs = environment.discount_spec()
     self.extra_specs = environment.extra_spec()
     for agent in environment.possible_agents:
         specs[agent] = EnvironmentSpec(
             observations=observation_specs[agent],
             actions=action_specs[agent],
             rewards=reward_specs[agent],
             discounts=discount_specs[agent],
         )
     return specs
Example No. 12
    def assert_env_reset(
        wrapped_env: dm_env.Environment,
        dm_env_timestep: dm_env.TimeStep,
        env_spec: EnvSpec,
    ) -> None:
        if env_spec.env_type == EnvType.Parallel:
            rewards_spec = wrapped_env.reward_spec()
            expected_rewards = {
                agent: convert_np_type(rewards_spec[agent].dtype, 0)
                for agent in wrapped_env.agents
            }

            discount_spec = wrapped_env.discount_spec()
            expected_discounts = {
                agent: convert_np_type(discount_spec[agent].dtype, 1)
                for agent in wrapped_env.agents
            }

            assert Helpers.compare_dicts(
                dm_env_timestep.reward,
                expected_rewards,
            ), "Failed to reset reward."
            assert Helpers.compare_dicts(
                dm_env_timestep.discount,
                expected_discounts,
            ), "Failed to reset discount."

        elif env_spec.env_type == EnvType.Sequential:
            for agent in wrapped_env.agents:
                rewards_spec = wrapped_env.reward_spec()
                expected_reward = convert_np_type(rewards_spec[agent].dtype, 0)

                discount_spec = wrapped_env.discount_spec()
                expected_discount = convert_np_type(discount_spec[agent].dtype,
                                                    1)

                assert dm_env_timestep.reward == expected_reward and type(
                    dm_env_timestep.reward) == type(
                        expected_reward), "Failed to reset reward."
                assert dm_env_timestep.discount == expected_discount and type(
                    dm_env_timestep.discount) == type(
                        expected_discount), "Failed to reset discount."
Example No. 13
 def observe(
     self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: int
 ) -> dm_env.TimeStep:
     #  iterate over the number of steps
     for t in range(self.hparams.n_steps):
         #  get new MDP state
         new_timestep = env.step(action)
         #  store transition into the replay buffer
         self.memory.add(timestep, action, new_timestep, preprocess=self.preprocess)
         timestep = new_timestep
     return timestep
Example No. 14
    def __init__(
        self,
        environment: dm_env.Environment,
        additional_discount: float = 0.99,
        max_abs_reward: Optional[float] = 1.0,
        resize_shape: Optional[Tuple[int, int]] = (84, 84),
        num_action_repeats: int = 4,
        num_pooled_frames: int = 2,
        zero_discount_on_life_loss: bool = True,
        num_stacked_frames: int = 4,
        grayscaling: bool = True,
    ):
        rgb_spec, unused_lives_spec = environment.observation_spec()
        if rgb_spec.shape[2] != 3:
            raise ValueError(
                'This wrapper assumes interleaved pixel observations with shape '
                '(height, width, channels).')
        if int(environment.action_spec().minimum) != 0:
            raise ValueError('This wrapper assumes zero-indexed actions.')

        self._environment = environment
        self._processor = atari(
            additional_discount=additional_discount,
            max_abs_reward=max_abs_reward,
            resize_shape=resize_shape,
            num_action_repeats=num_action_repeats,
            num_pooled_frames=num_pooled_frames,
            zero_discount_on_life_loss=zero_discount_on_life_loss,
            num_stacked_frames=num_stacked_frames,
            grayscaling=grayscaling,
        )

        if grayscaling:
            self._observation_shape = resize_shape + (num_stacked_frames, )
            self._observation_spec_name = 'grayscale'
        else:
            self._observation_shape = resize_shape + (3, num_stacked_frames)
            self._observation_spec_name = 'RGB'

        self._reset_next_step = True
Example No. 15
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the observation,
      action, discount and reward specs.

  Returns:
    A callable that given a batch_size returns an iterator with demonstrations.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    dataset = tf.data.Dataset.from_tensors(data).repeat()

    return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
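
Illustrative usage: the returned callable takes a batch size and yields batched numpy `types.Transition` tuples indefinitely.

make_iterator = transition_iterator(environment)
iterator = make_iterator(32)
batch = next(iterator)  # types.Transition with a leading batch dimension of 32.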
Example No. 16
def actor(server: Connection, client: Connection, env: dm_env.Environment):
    def _step(env, a: int):
        timestep = env.step(a)
        if timestep.last():
            timestep = env.reset()
        return timestep

    def _step_async(env, a: int, buffer: mp.Queue):
        timestep = env.step(a)
        buffer.put(timestep)
        print(buffer.qsize())
        if timestep.last():
            timestep = env.reset()
        return

    #  close copy of server connection from client process
    #  see: https://stackoverflow.com/q/8594909/6655465
    server.close()
    #  switch case command
    try:
        while True:
            cmd, data = client.recv()
            if cmd == "step":
                client.send(_step(env, data))
            elif cmd == "step_async":
                client.send(_step_async(env, *data))
            elif cmd == "reset":
                client.send(env.reset())
            elif cmd == "render":
                client.send(env.render(data))
            elif cmd == "close":
                client.send(env.close())
                break
            else:
                raise NotImplementedError("Command {} is not implemented".format(cmd))
    except KeyboardInterrupt:
        logging.info("SubprocVecEnv actor: got KeyboardInterrupt")
    finally:
        env.close()
Example No. 17
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the observation,
      action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReplaySample
    object indefinitely.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
Example No. 18
 def run(
     self,
     env: dm_env.Environment,
     num_episodes: int,
     eval: bool = False,
 ) -> Loss:
     logging.info(
         "Starting {} agent {} on environment {}.\nThe scheduled number of episode is {}"
         .format("evaluating" if eval else "training", self, env,
                 num_episodes))
     logging.info(
         "The hyperparameters for the current experiment are {}".format(
             self.hparams._asdict()))
     for episode in range(num_episodes):
         print(
             "Episode {}/{}\t\t\t".format(episode, num_episodes - 1),
             end="\r",
         )
         #  initialise environment
         episode_reward = 0.0
         timestep = env.reset()
         while not timestep.last():
             #  apply policy
             action = self.policy(timestep)
             #  observe new state
             new_timestep = self.observe(env, timestep, action)
             episode_reward += new_timestep.reward
             print(
                 "Episode reward {}\t\t".format(episode_reward),
                 end="\r",
             )
             #  update policy
             loss = None
             if not eval:
                 loss = self.update(timestep, action, new_timestep)
             #  log update
             if self.logging:
                 self.log(timestep, action, new_timestep, loss)
             # prepare next iteration
             timestep = new_timestep
     return loss
Example No. 19
    def __init__(
        self,
        agent: agent_lib.Agent,
        env: dm_env.Environment,
        unroll_length: int,
        learner: learner_lib.Learner,
        rng_seed: int = 42,
        logger=None,
    ):
        self._agent = agent
        self._env = env
        self._unroll_length = unroll_length
        self._learner = learner
        self._timestep = env.reset()
        self._agent_state = agent.initial_state(None)
        self._traj = []
        self._rng_key = jax.random.PRNGKey(rng_seed)

        if logger is None:
            logger = util.NullLogger()
        self._logger = logger

        self._episode_return = 0.
Example No. 20
    def __init__(self,
                 environment: dm_env.Environment,
                 *,
                 max_abs_reward: Optional[float] = None,
                 scale_dims: Optional[Tuple[int, int]] = (84, 84),
                 action_repeats: int = 4,
                 pooled_frames: int = 2,
                 zero_discount_on_life_loss: bool = False,
                 expose_lives_observation: bool = False,
                 num_stacked_frames: int = 4,
                 max_episode_len: Optional[int] = None,
                 to_float: bool = False,
                 grayscaling: bool = True):
        """Initializes a new AtariWrapper.

    Args:
      environment: An Atari environment.
      max_abs_reward: Maximum absolute reward value before clipping is applied.
        If set to `None` (default), no clipping is applied.
      scale_dims: Image size for the rescaling step after grayscaling, given as
        `(height, width)`. Set to `None` to disable resizing.
      action_repeats: Number of times to step wrapped environment for each given
        action.
      pooled_frames: Number of observations to pool over. Set to 1 to disable
        frame pooling.
      zero_discount_on_life_loss: If `True`, sets the discount to zero when the
        number of lives decreases in the Atari environment.
      expose_lives_observation: If `False`, the `lives` part of the observation
        is discarded, otherwise it is kept as part of an observation tuple. This
        does not affect the `zero_discount_on_life_loss` feature. When disabled,
        the observation consists of a single pixel array, otherwise it is a
        tuple (pixel_array, lives).
      num_stacked_frames: Number of recent (pooled) observations to stack into
        the returned observation.
      max_episode_len: Number of frames before truncating episode. By default,
        there is no maximum length.
      to_float: If `True`, rescales RGB observations to floats in [0, 1].
      grayscaling: If `True` returns a grayscale version of the observations. In
        this case, the observation is 3D (H, W, num_stacked_frames). If `False`
        the observations are RGB and have shape (H, W, C, num_stacked_frames).

    Raises:
      ValueError: For various invalid inputs.
    """
        if not 1 <= pooled_frames <= action_repeats:
            raise ValueError("pooled_frames ({}) must be between 1 and "
                             "action_repeats ({}) inclusive".format(
                                 pooled_frames, action_repeats))

        if zero_discount_on_life_loss:
            super().__init__(_ZeroDiscountOnLifeLoss(environment))
        else:
            super().__init__(environment)

        if not max_episode_len:
            max_episode_len = np.inf

        self._frame_stacker = frame_stacking.FrameStacker(
            num_frames=num_stacked_frames)
        self._action_repeats = action_repeats
        self._pooled_frames = pooled_frames
        self._scale_dims = scale_dims
        self._max_abs_reward = max_abs_reward or np.inf
        self._to_float = to_float
        self._expose_lives_observation = expose_lives_observation

        if scale_dims:
            self._height, self._width = scale_dims
        else:
            spec = environment.observation_spec()
            self._height, self._width = spec[RGB_INDEX].shape[:2]

        self._episode_len = 0
        self._max_episode_len = max_episode_len
        self._reset_next_step = True

        self._grayscaling = grayscaling

        # Based on underlying observation spec, decide whether lives are to be
        # included in output observations.
        observation_spec = self._environment.observation_spec()
        spec_names = [spec.name for spec in observation_spec]
        if "lives" in spec_names and spec_names.index("lives") != 1:
            raise ValueError(
                "`lives` observation needs to have index 1 in Atari.")

        self._observation_spec = self._init_observation_spec()

        self._raw_observation = None
Example No. 21
 def __init__(self, environment: dm_env.Environment, clip: bool = False):
     super().__init__(environment)
     self._action_spec = environment.action_spec()
     self._clip = clip
Example No. 22
 def __init__(self, environment: dm_env.Environment):
     self._environment = environment
     if int(environment.action_spec()[0].minimum) != 0:
         raise ValueError(
             'This wrapper assumes zero-indexed actions. Use the Atari setting '
             'zero_indexed_actions=\"true\" to get actions in this format.')