Example #1
  def testCreateAgentWithDefaults(self):
    # Verifies that we can create and train an agent with the default values.
    agent = full_rainbow_agent.JaxFullRainbowAgent(num_actions=4)
    observation = onp.ones([84, 84, 1])
    agent.begin_episode(observation)
    agent.step(reward=1, observation=observation)
    agent.end_episode(reward=1)
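The snippet above omits its imports. A self-contained sketch (the module path for `full_rainbow_agent` is an assumption based on Dopamine's JAX agent layout; the observation shape matches the test):

import numpy as onp
# Assumed import path (Dopamine's JAX agents); not shown in the original snippet.
from dopamine.jax.agents.full_rainbow import full_rainbow_agent

agent = full_rainbow_agent.JaxFullRainbowAgent(num_actions=4)

# One dummy 84x84x1 observation, as in the test above.
observation = onp.ones([84, 84, 1])
action = agent.begin_episode(observation)
for _ in range(3):
  # begin_episode/step return the action chosen for the next time step.
  action = agent.step(reward=1, observation=observation)
agent.end_episode(reward=1)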
Example #2
  def _create_test_agent(self):
    # This dummy network allows us to deterministically anticipate that
    # action 0 will be selected by an argmax.

    # In Rainbow we are dealing with a distribution over Q-values,
    # which are represented as num_atoms bins, ranging from -vmax to vmax.
    # The output layer will have num_actions * num_atoms elements,
    # so each group of num_atoms weights represents the logits for a
    # particular action. By setting 1s everywhere except for the first
    # num_atoms entries (the logits for the first action), which are
    # set to onp.arange(1, num_atoms + 1), we ensure that the first action
    # places higher weight on higher Q-values; this results in the first
    # action being chosen.
    class MockFullRainbowNetwork(nn.Module):
      """Custom Jax network used in tests."""
      num_actions: int
      num_atoms: int
      noisy: bool
      dueling: bool
      distributional: bool
      inputs_preprocessed: bool = False

      @nn.compact
      def __call__(self, x, support, eval_mode=False, key=None):

        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :self.num_atoms] = onp.arange(
              1, self.num_atoms + 1)
          return to_pick_first_action

        x = x.astype(jnp.float32)
        x = x.reshape((-1))  # flatten
        x = nn.Dense(
            features=self.num_actions * self.num_atoms,
            kernel_init=custom_init,
            bias_init=nn.initializers.ones)(x)
        logits = x.reshape((self.num_actions, self.num_atoms))
        if not self.distributional:
          qs = jnp.sum(logits, axis=-1)  # Sum over the num_atoms dimension.
          return atari_lib.DQNNetworkType(qs)
        probabilities = nn.softmax(logits)
        qs = jnp.sum(support * probabilities, axis=1)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)

    agent = full_rainbow_agent.JaxFullRainbowAgent(
        network=MockFullRainbowNetwork,
        num_actions=self._num_actions,
        num_atoms=self._num_atoms,
        vmax=self._vmax,
        distributional=True,
        epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
    )
    # This ensures non-random action choices (since epsilon_eval = 0.0) and
    # skips the train_step.
    agent.eval_mode = True
    return agent
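To see why the mock network above always yields action 0: the distributional branch of `__call__` turns each action's logits into probabilities with a softmax and takes the expectation over the support (num_atoms bins spanning -vmax to vmax). A minimal sketch with placeholder values (num_actions, num_atoms and vmax below are assumptions, not the test's actual settings):

import jax.numpy as jnp
from flax import linen as nn

num_actions, num_atoms, vmax = 4, 5, 10.0       # placeholder values, not the test's
support = jnp.linspace(-vmax, vmax, num_atoms)  # atom locations

# Mimic the custom kernel: uniform logits everywhere except action 0,
# whose logits increase with the atom index.
logits = jnp.ones((num_actions, num_atoms))
logits = logits.at[0].set(jnp.arange(1, num_atoms + 1, dtype=jnp.float32))

probabilities = nn.softmax(logits)               # per-action distributions over atoms
qs = jnp.sum(support * probabilities, axis=1)    # expected Q-value per action
assert int(jnp.argmax(qs)) == 0                  # action 0 puts more mass on high atoms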
Example #3
  def testStoreTransitionWithPrioritizedSampling(self):
    agent = full_rainbow_agent.JaxFullRainbowAgent(
        num_actions=4, replay_scheme='prioritized')
    dummy_frame = onp.zeros((84, 84))
    # Add transitions with priorities: default, 10., default.
    agent._store_transition(dummy_frame, 0, 0, False)
    agent._store_transition(dummy_frame, 0, 0, False, priority=10.)
    agent._store_transition(dummy_frame, 0, 0, False)
    returned_priorities = agent._replay.get_priority(
        onp.arange(self.stack_size - 1, self.stack_size + 2, dtype=onp.int32))
    expected_priorities = [1., 10., 10.]
    self.assertTrue(onp.array_equal(returned_priorities, expected_priorities))
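The expected priorities are [1., 10., 10.] because, under the prioritized scheme, a transition stored without an explicit priority defaults to the largest priority recorded so far (this is my reading of Dopamine's prioritized replay). The class below is a toy stand-in for that bookkeeping, not the real buffer:

class ToyPrioritizedStore:
  """Illustrative stand-in: default priority is the running maximum."""

  def __init__(self):
    self.max_recorded_priority = 1.0
    self.priorities = []

  def add(self, priority=None):
    if priority is None:
      priority = self.max_recorded_priority  # default to the running max
    self.max_recorded_priority = max(self.max_recorded_priority, priority)
    self.priorities.append(priority)

store = ToyPrioritizedStore()
store.add()              # -> 1.0  (initial max)
store.add(priority=10.)  # -> 10.0 (raises the max)
store.add()              # -> 10.0 (new running max)
assert store.priorities == [1.0, 10.0, 10.0]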
Example #4
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=environment.action_space.n,
                                  summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'full_rainbow':
        return full_rainbow_agent.JaxFullRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
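A hedged usage sketch for `create_agent`: the `atari_lib.create_atari_environment` factory and its import path are assumptions based on Dopamine's discrete_domains package, and the JAX agents never touch `sess`, so None suffices for 'full_rainbow' (the TF1 agents 'dqn', 'rainbow' and 'implicit_quantile' would need a real graph-mode `tf.compat.v1.Session` instead):

# Assumed import path for Dopamine's Atari environment factory.
from dopamine.discrete_domains import atari_lib

environment = atari_lib.create_atari_environment(game_name='Pong')
agent = create_agent(sess=None, environment=environment,
                     agent_name='full_rainbow')
observation = environment.reset()
action = agent.begin_episode(observation)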