def testCreateAgentWithDefaults(self):
  # Verifies that we can create and train an agent with the default values.
  agent = quantile_agent.JaxQuantileAgent(num_actions=4)
  observation = onp.ones([84, 84, 1])
  agent.begin_episode(observation)
  agent.step(reward=1, observation=observation)
  agent.end_episode(reward=1)
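# The test above exercises the standard Dopamine agent API
# (begin_episode -> step -> end_episode). Below is a minimal sketch of driving
# that same API over a full episode; it is illustrative only (not part of the
# test module) and assumes a gym-style `environment` whose step() returns the
# classic (observation, reward, done, info) 4-tuple. In practice Dopamine's
# Runner implements this loop.
def run_one_episode(agent, environment):
  """Illustrative episode loop; a sketch, not Dopamine's actual Runner."""
  observation = environment.reset()
  action = agent.begin_episode(observation)
  total_reward = 0.
  done = False
  while not done:
    observation, reward, done, _ = environment.step(action)
    total_reward += reward
    if not done:
      action = agent.step(reward, observation)
  agent.end_episode(reward)
  return total_reward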
def create_incoherent_agent(sess, environment, agent_name='incoherent_dqn',
                            summary_writer=None, debug_mode=False):
  """Creates an incoherent agent.

  Args:
    sess: TF session, unused since we are in JAX.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent for
      in-agent training statistics in Tensorboard.
    debug_mode: bool, unused.

  Returns:
    An RL agent.

  Raises:
    ValueError: If `agent_name` is not in the supported list.
  """
  assert agent_name is not None
  del sess
  del debug_mode
  if agent_name == 'dqn':
    return jax_dqn_agent.JaxDQNAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'quantile':
    return jax_quantile_agent.JaxQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return jax_rainbow_agent.JaxRainbowAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'incoherent_dqn':
    return incoherent_dqn_agent.IncoherentDQNAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'incoherent_implicit_quantile':
    return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'mimplicit_quantile':
    return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        coherence_weight=0.0,
        tau=0.03,
        summary_writer=summary_writer)
  elif agent_name == 'incoherent_mimplicit_quantile':
    return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        tau=0.03,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
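# Minimal usage sketch for the factory above. Illustrative only and not part
# of the original module: it assumes a gym-style `environment` (e.g. one built
# with dopamine.discrete_domains.atari_lib.create_atari_environment) and that
# the agent modules referenced in each branch are importable. `sess` is unused
# by the JAX agents, so None is passed.
def _example_create_incoherent_agent(environment):
  """Builds one of the incoherent agents for the given environment."""
  return create_incoherent_agent(
      None, environment, agent_name='incoherent_implicit_quantile')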
def create_agent(sess, environment, agent_name=None, summary_writer=None,
                 debug_mode=False):
  """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent for
      in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n,
                              summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_dqn':
    return jax_dqn_agent.JaxDQNAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_quantile':
    return jax_quantile_agent.JaxQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_rainbow':
    return jax_rainbow_agent.JaxRainbowAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_implicit_quantile':
    return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
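# Sketch of how a create_agent factory like the one above is typically wired
# into Dopamine's training loop. Illustrative only: `base_dir` and `gin_files`
# are placeholders supplied by the caller, and the exact TrainRunner keyword
# arguments can vary between Dopamine versions.
def _example_train(base_dir, gin_files):
  """Loads gin configs and trains using the create_agent factory above."""
  from dopamine.discrete_domains import run_experiment
  run_experiment.load_gin_configs(gin_files, gin_bindings=[])
  runner = run_experiment.TrainRunner(base_dir, create_agent)
  runner.run_experiment()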
def _create_test_agent(self):
  """Creates a JaxQuantileAgent with a mock Flax network for tests."""
  # This dummy network allows us to deterministically anticipate that
  # action 0 will be selected by an argmax.
  # In Quantile we are dealing with a distribution over Q-values,
  # which are represented as num_atoms quantiles.
  # The output layer will have num_actions * num_atoms elements,
  # so each group of num_atoms weights represents the value quantiles for
  # a particular action. By setting 1s everywhere, except for the first
  # num_atoms (representing the quantiles for the first action), which
  # are set to onp.arange(1, num_atoms + 1), we are ensuring that the first
  # action has a higher expected Q-value; this results in the first
  # action being chosen.
  class MockQuantileNetwork(linen.Module):
    """Custom Jax network used in tests."""
    num_actions: int
    num_atoms: int
    inputs_preprocessed: bool = False

    @linen.compact
    def __call__(self, x):
      def custom_init(key, shape, dtype=jnp.float32):
        del key
        to_pick_first_action = onp.ones(shape, dtype)
        to_pick_first_action[:, :self.num_atoms] = onp.arange(
            1, self.num_atoms + 1)
        return to_pick_first_action

      x = x.astype(jnp.float32)
      x = x.reshape((-1))  # flatten
      x = linen.Dense(features=self.num_actions * self.num_atoms,
                      kernel_init=custom_init,
                      bias_init=linen.initializers.ones)(x)
      logits = x.reshape((self.num_actions, self.num_atoms))
      probabilities = linen.softmax(logits)
      qs = jnp.mean(logits, axis=1)
      return atari_lib.RainbowNetworkType(qs, logits, probabilities)

  agent = quantile_agent.JaxQuantileAgent(
      network=MockQuantileNetwork,
      num_actions=self.num_actions,
      num_atoms=self._num_atoms,
      min_replay_history=self._min_replay_history,
      epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
      epsilon_eval=0.0,
      epsilon_decay_period=self._epsilon_decay_period)
  # This ensures non-random action choices (since epsilon_eval = 0.0) and
  # skips the train_step.
  agent.eval_mode = True
  return agent
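# Illustrative numpy check (hypothetical, not part of the test class above) of
# why the mock network always makes action 0 win the argmax: with an all-ones
# input of dimension d, every quantile for actions 1..num_actions-1 equals
# d + 1 (ones kernel plus ones bias), while action 0's quantiles are
# d * [1, ..., num_atoms] + 1, so its mean Q-value is strictly larger whenever
# num_atoms > 1.
import numpy as onp

d, num_actions, num_atoms = 4, 3, 5
kernel = onp.ones((d, num_actions * num_atoms))
kernel[:, :num_atoms] = onp.arange(1, num_atoms + 1)
logits = (onp.ones(d) @ kernel + 1.0).reshape(num_actions, num_atoms)
q_values = logits.mean(axis=1)
assert q_values.argmax() == 0  # Action 0 has the highest mean quantile value.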