Example #1
    def testCreateAgentWithDefaults(self):
        # Verifies that we can create and train an agent with the default values.
        agent = dqn_agent.JaxDQNAgent(num_actions=4)
        observation = onp.ones([84, 84, 1])
        agent.begin_episode(observation)
        agent.step(reward=1, observation=observation)
        agent.end_episode(reward=1)
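The snippet above assumes the Dopamine JAX DQN module and NumPy are already imported. A minimal header sketch for such a test file (the test-class name is hypothetical, shown only to give the method a home):

import numpy as onp
from absl.testing import absltest
from dopamine.jax.agents.dqn import dqn_agent


class JaxDQNAgentTest(absltest.TestCase):
    ...  # testCreateAgentWithDefaults and related tests go here.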
Example #2
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=environment.action_space.n,
                                  summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
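For reference, a minimal sketch of calling this factory; the Atari environment helper and session setup below are assumptions based on Dopamine's public utilities, not part of the example itself:

# Sketch: create an Atari environment and build a DQN agent via the factory.
import tensorflow as tf
from dopamine.discrete_domains import atari_lib

environment = atari_lib.create_atari_environment(game_name='Pong')
with tf.compat.v1.Session() as sess:
    agent = create_agent(sess, environment, agent_name='dqn')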
Example #3
def create_incoherent_agent(sess,
                            environment,
                            agent_name='incoherent_dqn',
                            summary_writer=None,
                            debug_mode=False):
    """Creates an incoherent agent.

  Args:
    sess: TF session, unused since we are in JAX.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, unused.

  Returns:
    An active and passive agent.
  """
    assert agent_name is not None
    del sess
    del debug_mode
    if agent_name == 'dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'incoherent_dqn':
        return incoherent_dqn_agent.IncoherentDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'incoherent_implicit_quantile':
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'mimplicit_quantile':
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            coherence_weight=0.0,
            tau=0.03,
            summary_writer=summary_writer)
    elif agent_name == 'incoherent_mimplicit_quantile':
        return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            tau=0.03,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
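Although the signature mirrors the TensorFlow factory above, every branch here returns a JAX agent, so `sess` and `debug_mode` are accepted only for interface compatibility and are deleted immediately. Dispatch is a plain `if`/`elif` chain on `agent_name`, and the incoherent implicit-quantile variants differ only in constructor arguments such as `coherence_weight` and `tau`.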
Example #4
    def _create_test_agent(self, allow_partial_reload=False):

        # This dummy network allows us to deterministically anticipate that
        # action 0 will be selected by an argmax.
        class MockDQNNetwork(nn.Module):
            """The Jax network used in tests."""
            num_actions: int
            inputs_preprocessed: bool = False

            @nn.compact
            def __call__(self, x):
                # This weights_initializer gives action 0 a higher weight, ensuring
                # that it gets picked by the argmax.
                def custom_init(key, shape, dtype=jnp.float32):
                    del key
                    to_pick_first_action = onp.zeros(shape, dtype)
                    to_pick_first_action[:, 0] = 1
                    return to_pick_first_action

                x = x.astype(jnp.float32)
                x = x.reshape((-1))  # flatten
                x = nn.Dense(features=self.num_actions,
                             kernel_init=custom_init,
                             bias_init=nn.initializers.ones)(x)
                return atari_lib.DQNNetworkType(x)

        agent = dqn_agent.JaxDQNAgent(
            network=MockDQNNetwork,
            observation_shape=self.observation_shape,
            observation_dtype=self.observation_dtype,
            stack_size=self.stack_size,
            num_actions=self.num_actions,
            min_replay_history=self.min_replay_history,
            epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
            update_period=self.update_period,
            target_update_period=self.target_update_period,
            epsilon_eval=0.0,  # No exploration during evaluation.
            allow_partial_reload=allow_partial_reload)
        # This ensures non-random action choices (since epsilon_eval = 0.0) and
        # skips the train_step.
        agent.eval_mode = True
        return agent
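A minimal sketch of a test (hypothetical method name) that exercises the mock: the custom kernel initializer gives action 0 the largest Q-value and `epsilon_eval = 0.0` makes the argmax deterministic, so the first action must be 0.

    def testBeginEpisodeSelectsActionZero(self):
        agent = self._create_test_agent()
        observation = onp.ones(self.observation_shape)
        self.assertEqual(agent.begin_episode(observation), 0)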
Example #5
cartpole_config = """
JaxDQNAgent.update_period = 4
JaxDQNAgent.target_update_period = 100
JaxDQNAgent.epsilon_fn = @dqn_agent.identity_epsilon

create_optimizer.name = 'adam'
create_optimizer.learning_rate = 0.001
create_optimizer.eps = 3.125e-4

OutOfGraphReplayBuffer.replay_capacity = 50000
OutOfGraphReplayBuffer.batch_size = 128
"""
gin.parse_config(cartpole_config, skip_unknown=False)
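Parsing the config binds these values through gin, so the agent, optimizer, and replay buffer constructed below pick up `update_period = 4`, the Adam settings, and the 50000-transition replay buffer without those arguments being passed explicitly.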

dqn_agent = dqn_agent.JaxDQNAgent(num_actions=cartpole_env.action_space.n,
                                  observation_shape=(4, 1),
                                  observation_dtype=jnp.float64,
                                  stack_size=1,
                                  network=networks.CartpoleDQNNetwork)

def learned_policy(s):
  return dqn_agent.step(0., s)  # We pass in a dummy reward

# We set the agent to `eval_mode` to prevent it from continuing to train while
# it interacts with the environment.
dqn_agent.eval_mode = True
#animate_agent(learned_policy, cartpole_env, num_frames=100)

max_steps_per_episode = 200  # @param {type:'slider', min:10, max:1000}
training_steps = 1000  # @param {type:'slider', min:10, max:5000}
num_iterations = 30  # @param {type:'slider', min:10, max:200}
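The three sliders above describe an iteration/episode training budget. A minimal sketch of a hand-rolled interaction loop that consumes them, assuming `cartpole_env` follows the classic Gym step API returning (observation, reward, done, info):

# Sketch: run a single episode with the agent's begin_episode/step/end_episode API.
def run_one_episode(env, agent, max_steps):
    observation = env.reset()
    action = agent.begin_episode(observation)
    total_reward = 0.
    for _ in range(max_steps):
        observation, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            agent.end_episode(reward)
            break
        action = agent.step(reward, observation)
    return total_reward

In practice these hyperparameters would more typically drive Dopamine's runner utilities (e.g. `run_experiment`); the loop above only illustrates how `begin_episode`, `step`, and `end_episode` fit together for a single episode.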