def testCreateAgentWithDefaults(self):
   """Smoke test: a default-configured JaxRainbowAgent runs one episode.

   Builds the agent with only `num_actions` specified, then drives it
   through begin_episode / step / end_episode with an all-ones observation.
   The test makes no assertions; it passes as long as nothing raises.
   """
   dummy_observation = onp.ones((84, 84, 1))
   default_agent = rainbow_agent.JaxRainbowAgent(num_actions=4)
   default_agent.begin_episode(dummy_observation)
   default_agent.step(reward=1, observation=dummy_observation)
   default_agent.end_episode(reward=1)
Beispiel #2
0
def create_incoherent_agent(sess,
                            environment,
                            agent_name='incoherent_dqn',
                            summary_writer=None,
                            debug_mode=False):
    """Builds the JAX agent selected by `agent_name`.

    Args:
      sess: Unused; accepted only for signature compatibility and discarded.
      environment: Object exposing `action_space.n`, which is forwarded to the
        agent as `num_actions`. Only accessed when `agent_name` is recognised.
      agent_name: str, one of 'dqn', 'quantile', 'rainbow',
        'implicit_quantile', 'incoherent_dqn', 'incoherent_implicit_quantile',
        'mimplicit_quantile' or 'incoherent_mimplicit_quantile'.
      summary_writer: Forwarded unchanged to the agent constructor.
      debug_mode: Unused; discarded.

    Returns:
      The newly constructed agent instance.

    Raises:
      AssertionError: If `agent_name` is None.
      ValueError: If `agent_name` is not one of the supported names.
    """
    assert agent_name is not None
    del sess
    del debug_mode

    # Each row: (name, zero-arg callable resolving the agent class, extra
    # constructor kwargs). Classes are resolved lazily so that only the
    # selected agent's module is touched, exactly as in an if/elif chain.
    resolve_incoherent_iqn = lambda: (
        incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent)
    agent_table = (
        ('dqn', lambda: jax_dqn_agent.JaxDQNAgent, {}),
        ('quantile', lambda: jax_quantile_agent.JaxQuantileAgent, {}),
        ('rainbow', lambda: jax_rainbow_agent.JaxRainbowAgent, {}),
        ('implicit_quantile',
         lambda: jax_implicit_quantile_agent.JaxImplicitQuantileAgent, {}),
        ('incoherent_dqn',
         lambda: incoherent_dqn_agent.IncoherentDQNAgent, {}),
        ('incoherent_implicit_quantile', resolve_incoherent_iqn, {}),
        ('mimplicit_quantile', resolve_incoherent_iqn,
         {'coherence_weight': 0.0, 'tau': 0.03}),
        ('incoherent_mimplicit_quantile', resolve_incoherent_iqn,
         {'tau': 0.03}),
    )

    for known_name, resolve_class, extra_kwargs in agent_table:
        if agent_name == known_name:
            agent_class = resolve_class()
            return agent_class(num_actions=environment.action_space.n,
                               summary_writer=summary_writer,
                               **extra_kwargs)

    raise ValueError('Unknown agent: {}'.format(agent_name))
Beispiel #3
0
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Builds the RL agent selected by `agent_name`.

    Args:
      sess: A `tf.compat.v1.Session` object. Passed as the first positional
        argument to the 'dqn', 'rainbow' and 'implicit_quantile' agents; the
        'jax_*' agents do not receive it.
      environment: Object exposing `action_space.n`, which is forwarded to the
        agent as `num_actions`. Only accessed when `agent_name` is recognised.
      agent_name: str, one of 'dqn', 'rainbow', 'implicit_quantile',
        'jax_dqn', 'jax_quantile', 'jax_rainbow' or 'jax_implicit_quantile'.
      summary_writer: A summary writer forwarded to the agent. It is replaced
        by None unless `debug_mode` is truthy.
      debug_mode: bool, whether the agent should receive `summary_writer`.
        Disabled by default.

    Returns:
      agent: The newly constructed agent instance.

    Raises:
      AssertionError: If `agent_name` is None.
      ValueError: If `agent_name` is not in the supported list.
    """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None

    # Each row: (name, zero-arg callable resolving the agent class, whether
    # the constructor takes `sess` as its leading positional argument).
    # Classes are resolved lazily so only the selected module is touched.
    agent_table = (
        ('dqn', lambda: dqn_agent.DQNAgent, True),
        ('rainbow', lambda: rainbow_agent.RainbowAgent, True),
        ('implicit_quantile',
         lambda: implicit_quantile_agent.ImplicitQuantileAgent, True),
        ('jax_dqn', lambda: jax_dqn_agent.JaxDQNAgent, False),
        ('jax_quantile', lambda: jax_quantile_agent.JaxQuantileAgent, False),
        ('jax_rainbow', lambda: jax_rainbow_agent.JaxRainbowAgent, False),
        ('jax_implicit_quantile',
         lambda: jax_implicit_quantile_agent.JaxImplicitQuantileAgent, False),
    )

    for known_name, resolve_class, takes_session in agent_table:
        if agent_name == known_name:
            agent_class = resolve_class()
            leading_args = (sess,) if takes_session else ()
            return agent_class(*leading_args,
                               num_actions=environment.action_space.n,
                               summary_writer=summary_writer)

    raise ValueError('Unknown agent: {}'.format(agent_name))
 def testStoreTransitionWithPrioritizedSamplingy(self):
   """Checks priorities recorded by `_store_transition` under prioritized replay.

   Three transitions are stored: the first and third without an explicit
   priority, the second with priority 10. The priorities read back from the
   replay buffer are expected to be [1., 10., 10.].
   """
   agent = rainbow_agent.JaxRainbowAgent(
       num_actions=4, replay_scheme='prioritized')
   dummy_frame = onp.zeros((84, 84))
   # Adding transitions with default, 10., default priorities.
   agent._store_transition(dummy_frame, 0, 0, False)
   agent._store_transition(dummy_frame, 0, 0, False, priority=10.)
   agent._store_transition(dummy_frame, 0, 0, False)
   # The three transitions are looked up at indices stack_size - 1 through
   # stack_size + 1 of the replay buffer.
   returned_priorities = agent._replay.get_priority(
       onp.arange(self.stack_size - 1, self.stack_size + 2, dtype=onp.int32))
   expected_priorities = [1., 10., 10.]
   # Previously the result of onp.array_equal was computed and discarded, so
   # the test could never fail. Assert on the comparison instead.
   onp.testing.assert_array_equal(returned_priorities, expected_priorities)
    def _create_test_agent(self):
        """Returns a JaxRainbowAgent in eval mode backed by a mock network.

        Rainbow works with a distribution over Q-values represented as
        num_atoms bins per action, so the mock network's output layer has
        num_actions * num_atoms elements, with each consecutive group of
        num_atoms values holding the logits for one action. The dense kernel
        is all ones except for the first num_atoms output columns (the first
        action's logits), which are set to 1..num_atoms. The intent is that
        the first action puts more weight on higher Q-values, so an argmax
        deterministically picks action 0.
        """

        class MockRainbowNetwork(linen.Module):
            """Custom Jax network used in tests."""
            num_actions: int
            num_atoms: int
            inputs_preprocessed: bool = False

            @linen.compact
            def __call__(self, x, support):
                def custom_init(key, shape, dtype=jnp.float32):
                    # Deterministic kernel: the RNG key is deliberately unused.
                    del key
                    kernel = onp.ones(shape, dtype)
                    first_action_logits = onp.arange(1, self.num_atoms + 1)
                    kernel[:, :self.num_atoms] = first_action_logits
                    return kernel

                # Cast to float32 and flatten to a 1-D vector.
                flat_input = x.astype(jnp.float32).reshape(-1)
                output_layer = linen.Dense(
                    features=self.num_actions * self.num_atoms,
                    kernel_init=custom_init,
                    bias_init=linen.initializers.ones)
                flat_logits = output_layer(flat_input)
                logits = flat_logits.reshape(
                    (self.num_actions, self.num_atoms))
                probabilities = linen.softmax(logits)
                # Expected value per action: support-weighted sum over atoms.
                qs = jnp.sum(support * probabilities, axis=1)
                return atari_lib.RainbowNetworkType(qs, logits, probabilities)

        # Epsilon schedule that always returns 0.0, i.e. no exploration.
        no_exploration = lambda w, x, y, z: 0.0
        test_agent = rainbow_agent.JaxRainbowAgent(
            network=MockRainbowNetwork,
            num_actions=self._num_actions,
            num_atoms=self._num_atoms,
            vmax=self._vmax,
            min_replay_history=self._min_replay_history,
            epsilon_fn=no_exploration,
            epsilon_eval=0.0,
            epsilon_decay_period=self._epsilon_decay_period)
        # Eval mode is meant to give non-random action choices (since
        # epsilon_eval = 0.0) and to skip the train_step.
        test_agent.eval_mode = True
        return test_agent