Exemplo n.º 1
0
 def testCreateAgentWithDefaults(self):
   """Smoke test: a default-configured agent runs through a one-step episode."""
   default_agent = quantile_agent.JaxQuantileAgent(num_actions=4)
   # A constant all-ones 84x84x1 frame is reused for every call below.
   frame = onp.ones([84, 84, 1])
   default_agent.begin_episode(frame)
   default_agent.step(reward=1, observation=frame)
   default_agent.end_episode(reward=1)
Exemplo n.º 2
0
def create_incoherent_agent(sess,
                            environment,
                            agent_name='incoherent_dqn',
                            summary_writer=None,
                            debug_mode=False):
    """Builds the JAX agent registered under `agent_name`.

    Args:
      sess: TF session; discarded, since the agents built here run on JAX.
      environment: A gym environment (e.g. Atari 2600). Only
        `environment.action_space.n` is read, to set `num_actions`.
      agent_name: str, one of 'dqn', 'quantile', 'rainbow',
        'implicit_quantile', 'incoherent_dqn', 'incoherent_implicit_quantile',
        'mimplicit_quantile' or 'incoherent_mimplicit_quantile'.
      summary_writer: Forwarded unchanged to the agent constructor as
        `summary_writer`.
      debug_mode: bool; discarded.

    Returns:
      The newly constructed agent instance.

    Raises:
      AssertionError: If `agent_name` is None (when assertions are enabled).
      ValueError: If `agent_name` is not one of the supported names.
    """
    assert agent_name is not None
    del sess, debug_mode

    def _incoherent_iqn():
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent

    # (name, class resolver, extra constructor kwargs). Classes are resolved
    # lazily so that only the module of the selected agent is ever touched.
    agent_table = (
        ('dqn', lambda: jax_dqn_agent.JaxDQNAgent, {}),
        ('quantile', lambda: jax_quantile_agent.JaxQuantileAgent, {}),
        ('rainbow', lambda: jax_rainbow_agent.JaxRainbowAgent, {}),
        ('implicit_quantile',
         lambda: jax_implicit_quantile_agent.JaxImplicitQuantileAgent, {}),
        ('incoherent_dqn',
         lambda: incoherent_dqn_agent.IncoherentDQNAgent, {}),
        ('incoherent_implicit_quantile', _incoherent_iqn, {}),
        ('mimplicit_quantile', _incoherent_iqn,
         {'coherence_weight': 0.0, 'tau': 0.03}),
        ('incoherent_mimplicit_quantile', _incoherent_iqn, {'tau': 0.03}),
    )

    for known_name, resolve_class, extra_kwargs in agent_table:
        if agent_name == known_name:
            agent_class = resolve_class()
            return agent_class(num_actions=environment.action_space.n,
                               **extra_kwargs,
                               summary_writer=summary_writer)

    raise ValueError('Unknown agent: {}'.format(agent_name))
Exemplo n.º 3
0
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Builds the agent registered under `agent_name`.

    Args:
      sess: A `tf.compat.v1.Session` object. It is passed as the first
        positional argument to the 'dqn', 'rainbow' and 'implicit_quantile'
        agents; the 'jax_*' agents do not receive it.
      environment: A gym environment (e.g. Atari 2600). Only
        `environment.action_space.n` is read, to set `num_actions`.
      agent_name: str, one of 'dqn', 'rainbow', 'implicit_quantile',
        'jax_dqn', 'jax_quantile', 'jax_rainbow' or 'jax_implicit_quantile'.
      summary_writer: A Tensorflow summary writer handed to the agent for
        in-agent training statistics. It is replaced by None unless
        `debug_mode` is true.
      debug_mode: bool, whether the agent receives `summary_writer`. Disabled
        by default as in-episode summaries result in slower training.

    Returns:
      agent: An RL agent.

    Raises:
      AssertionError: If `agent_name` is None (when assertions are enabled).
      ValueError: If `agent_name` is not in the supported list.
    """
    assert agent_name is not None
    writer = summary_writer if debug_mode else None

    # (name, class resolver, whether the constructor takes `sess` first).
    # Classes are resolved lazily so that only the module of the selected
    # agent is ever touched.
    agent_table = (
        ('dqn', lambda: dqn_agent.DQNAgent, True),
        ('rainbow', lambda: rainbow_agent.RainbowAgent, True),
        ('implicit_quantile',
         lambda: implicit_quantile_agent.ImplicitQuantileAgent, True),
        ('jax_dqn', lambda: jax_dqn_agent.JaxDQNAgent, False),
        ('jax_quantile', lambda: jax_quantile_agent.JaxQuantileAgent, False),
        ('jax_rainbow', lambda: jax_rainbow_agent.JaxRainbowAgent, False),
        ('jax_implicit_quantile',
         lambda: jax_implicit_quantile_agent.JaxImplicitQuantileAgent, False),
    )

    for known_name, resolve_class, takes_session in agent_table:
        if agent_name == known_name:
            agent_class = resolve_class()
            leading_args = (sess,) if takes_session else ()
            return agent_class(*leading_args,
                               num_actions=environment.action_space.n,
                               summary_writer=writer)

    raise ValueError('Unknown agent: {}'.format(agent_name))
Exemplo n.º 4
0
    def _create_test_agent(self):
        """Builds a JaxQuantileAgent in eval mode, backed by a mock network."""

        # The mock network's output layer has num_actions * num_atoms units;
        # after reshaping, each row of num_atoms values holds the quantiles
        # for one action. Every kernel weight and bias is 1, except the first
        # num_atoms kernel columns (the first action's quantiles), which are
        # 1..num_atoms. The intent is that the first action gets the highest
        # mean quantile value, so an argmax over Q-values picks action 0
        # (this presumes non-negative, non-zero inputs -- confirm against the
        # observations used by the tests).
        class MockQuantileNetwork(linen.Module):
            """Custom Jax network used in tests."""
            num_actions: int
            num_atoms: int
            inputs_preprocessed: bool = False

            @linen.compact
            def __call__(self, x):
                def first_action_kernel(key, shape, dtype=jnp.float32):
                    # Deterministic initializer: the PRNG key is ignored.
                    del key
                    kernel = onp.ones(shape, dtype)
                    kernel[:, :self.num_atoms] = onp.arange(
                        1, self.num_atoms + 1)
                    return kernel

                # Cast to float32 and flatten to a single vector.
                flat_input = x.astype(jnp.float32).reshape(-1)
                output_layer = linen.Dense(
                    features=self.num_actions * self.num_atoms,
                    kernel_init=first_action_kernel,
                    bias_init=linen.initializers.ones)
                quantile_values = output_layer(flat_input).reshape(
                    (self.num_actions, self.num_atoms))
                probabilities = linen.softmax(quantile_values)
                q_values = jnp.mean(quantile_values, axis=1)
                return atari_lib.RainbowNetworkType(
                    q_values, quantile_values, probabilities)

        test_agent = quantile_agent.JaxQuantileAgent(
            network=MockQuantileNetwork,
            num_actions=self.num_actions,
            num_atoms=self._num_atoms,
            min_replay_history=self._min_replay_history,
            epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
            epsilon_eval=0.0,
            epsilon_decay_period=self._epsilon_decay_period)
        # Eval mode ensures non-random action choices (since
        # epsilon_eval = 0.0) and skips the train_step.
        test_agent.eval_mode = True
        return test_agent