def testCreateAgentWithDefaults(self):
    # Verifies that we can create and train an agent with the default values.
    agent = implicit_quantile_agent.JaxImplicitQuantileAgent(num_actions=4)
    observation = onp.ones([84, 84, 1])
    agent.begin_episode(observation)
    agent.step(reward=1, observation=observation)
    agent.end_episode(reward=1)
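The test above assumes imports along these lines; the module paths follow the public Dopamine layout and are stated here as an assumption:

import numpy as onp
from dopamine.jax.agents.implicit_quantile import implicit_quantile_agent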
Example No. 2
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=environment.action_space.n,
                                  summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
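A minimal, hypothetical call site for create_agent; the environment and session below are illustrative assumptions, and the JAX agents ignore the session argument:

import gym
import tensorflow.compat.v1 as tf

environment = gym.make('CartPole-v0')  # any env with a discrete action space
sess = tf.Session()                    # only the TF-based agents use this
agent = create_agent(sess, environment, agent_name='jax_dqn')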
Example No. 3
def create_incoherent_agent(sess,
                            environment,
                            agent_name='incoherent_dqn',
                            summary_writer=None,
                            debug_mode=False):
    """Creates an incoherent agent.

  Args:
    sess: TF session, unused since we are in JAX.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, unused.

  Returns:
    An active and passive agent.
  """
    assert agent_name is not None
    del sess
    del debug_mode
    if agent_name == 'dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'incoherent_dqn':
        return incoherent_dqn_agent.IncoherentDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'incoherent_implicit_quantile':
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'mimplicit_quantile':
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            coherence_weight=0.0,
            tau=0.03,
            summary_writer=summary_writer)
    elif agent_name == 'incoherent_mimplicit_quantile':
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            tau=0.03,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
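Because the JAX agents do not use the TensorFlow session, a hypothetical call can simply pass None for it; here environment is any Gym-style environment with a discrete action space, and the agent name is one of the options handled above:

agent = create_incoherent_agent(None, environment,
                                agent_name='incoherent_implicit_quantile')
Example No. 4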
    def _create_test_agent(self):
        # This dummy network allows us to deterministically anticipate that the
        # state-action-quantile outputs will be equal to sum of the
        # corresponding quantile inputs.
        # State/Quantile shapes will be k x 1, (N x batch_size) x 1,
        # or (N' x batch_size) x 1.

        class MockImplicitQuantileNetwork(linen.Module):
            """Custom Jax model used in tests."""
            num_actions: int
            quantile_embedding_dim: int
            inputs_preprocessed: bool = False

            @linen.compact
            def __call__(self, x, num_quantiles, rng):
                del rng
                x = x.reshape((-1))  # flatten
                state_net_tiled = jnp.tile(x, [num_quantiles, 1])
                x *= state_net_tiled
                quantile_values = linen.Dense(
                    features=self.num_actions,
                    kernel_init=linen.initializers.ones,
                    bias_init=linen.initializers.zeros)(x)
                quantiles = jnp.ones([num_quantiles, 1])
                return atari_lib.ImplicitQuantileNetworkType(
                    quantile_values, quantiles)

        agent = implicit_quantile_agent.JaxImplicitQuantileAgent(
            network=MockImplicitQuantileNetwork,
            num_actions=self._num_actions,
            kappa=1.0,
            num_tau_samples=2,
            num_tau_prime_samples=3,
            num_quantile_samples=4,
            epsilon_eval=0.0)
        # This ensures non-random action choices (since epsilon_eval = 0.0) and
        # skips the train_step.
        agent.eval_mode = True
        return agent
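As a standalone sanity check (not part of the test above), the property the mock relies on can be verified directly: a Dense layer with a ones kernel and zero bias returns the sum of its inputs for every output unit, which is what makes the mock's quantile values predictable. The shapes and values below are illustrative:

import jax
import jax.numpy as jnp
from flax import linen

dense = linen.Dense(features=3,
                    kernel_init=linen.initializers.ones,
                    bias_init=linen.initializers.zeros)
x = jnp.array([[1.0, 2.0, 4.0]])
params = dense.init(jax.random.PRNGKey(0), x)
print(dense.apply(params, x))  # [[7. 7. 7.]]: each output is sum(x)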
Example No. 5
  def _create_test_agent(self):
    # This dummy network allows us to deterministically anticipate that the
    # state-action-quantile outputs will be equal to sum of the
    # corresponding quantile inputs.
    # State/Quantile shapes will be k x 1, (N x batch_size) x 1,
    # or (N' x batch_size) x 1.

    class MockImplicitQuantileNetwork(nn.Module):
      """Custom Jax model used in tests."""

      def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles,
                rng):
        del rng
        # With a ones kernel and zero bias, every action's quantile value is
        # simply the sum of the flattened inputs.
        batch_size = x.shape[0]
        x = x[None, :]
        x = x.astype(jnp.float32)
        x = x.reshape((x.shape[0], -1))  # flatten
        quantile_values = nn.Dense(x, features=num_actions,
                                   kernel_init=jax.nn.initializers.ones,
                                   bias_init=jax.nn.initializers.zeros)
        quantiles = jnp.ones([num_quantiles * batch_size, 1])
        return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)

    agent = implicit_quantile_agent.JaxImplicitQuantileAgent(
        network=MockImplicitQuantileNetwork,
        num_actions=self._num_actions,
        kappa=1.0,
        num_tau_samples=2,
        num_tau_prime_samples=3,
        num_quantile_samples=4,
        epsilon_eval=0.0)
    # This ensures non-random action choices (since epsilon_eval = 0.0) and
    # skips the train_step.
    agent.eval_mode = True
    return agent
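A sketch of how such a test agent might be exercised inside a unit test; the observation shape and self._num_actions are assumptions carried over from the snippets above:

agent = self._create_test_agent()
observation = onp.ones([84, 84, 1])
action = agent.begin_episode(observation)
self.assertGreaterEqual(action, 0)
self.assertLess(action, self._num_actions)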