def testCreateAgentWithDefaults(self):
  # Verifies that we can create and train an agent with the default values.
  agent = dqn_agent.JaxDQNAgent(num_actions=4)
  observation = onp.ones([84, 84, 1])
  agent.begin_episode(observation)
  agent.step(reward=1, observation=observation)
  agent.end_episode(reward=1)
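# --- Illustrative sketch (not part of the original test file) ---
# The test above depends on imports from its surrounding module. A
# hypothetical standalone harness for it could look like the following; the
# absltest runner, the class name, and the extra range check on the returned
# action are assumptions, not part of the original excerpt.
from absl.testing import absltest
import numpy as onp
from dopamine.jax.agents.dqn import dqn_agent


class JaxDQNAgentDefaultsTest(absltest.TestCase):

  def testCreateAgentWithDefaults(self):
    # Create and step a JaxDQNAgent with all-default hyperparameters.
    agent = dqn_agent.JaxDQNAgent(num_actions=4)
    observation = onp.ones([84, 84, 1])
    first_action = agent.begin_episode(observation)
    self.assertTrue(0 <= first_action < 4)
    agent.step(reward=1, observation=observation)
    agent.end_episode(reward=1)


if __name__ == '__main__':
  absltest.main()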
def create_agent(sess, environment, agent_name=None, summary_writer=None,
                 debug_mode=False):
  """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n,
                              summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_dqn':
    return jax_dqn_agent.JaxDQNAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_quantile':
    return jax_quantile_agent.JaxQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_rainbow':
    return jax_rainbow_agent.JaxRainbowAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_implicit_quantile':
    return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
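# --- Usage sketch (not part of the original module) ---
# `create_agent` is normally handed to a Runner rather than called by hand.
# The wiring below is a sketch under the assumption that the standard
# `dopamine.discrete_domains.run_experiment` module is available and that the
# relevant gin configs (e.g. the Atari game_name binding) have already been
# parsed; '/tmp/dqn_test' is just a placeholder directory.
import functools

from dopamine.discrete_domains import run_experiment

runner = run_experiment.Runner(
    base_dir='/tmp/dqn_test',
    create_agent_fn=functools.partial(create_agent, agent_name='jax_dqn'))
runner.run_experiment()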
def create_incoherent_agent(sess, environment, agent_name='incoherent_dqn',
                            summary_writer=None, debug_mode=False):
  """Creates an incoherent agent.

  Args:
    sess: TF session, unused since we are in JAX.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, unused.

  Returns:
    An agent of the requested type.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  del sess
  del debug_mode
  if agent_name == 'dqn':
    return jax_dqn_agent.JaxDQNAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'quantile':
    return jax_quantile_agent.JaxQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return jax_rainbow_agent.JaxRainbowAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'incoherent_dqn':
    return incoherent_dqn_agent.IncoherentDQNAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'incoherent_implicit_quantile':
    return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'mimplicit_quantile':
    return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        coherence_weight=0.0,
        tau=0.03,
        summary_writer=summary_writer)
  elif agent_name == 'incoherent_mimplicit_quantile':
    return incoherent_implicit_quantile_agent.IncoherentImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        tau=0.03,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
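# --- Usage sketch (not part of the original module) ---
# The factory can also be called directly; since the JAX agents ignore the
# session argument, passing None is enough. 'Pong' is only an example game,
# and the atari_lib import path is an assumption based on Dopamine's layout.
from dopamine.discrete_domains import atari_lib

env = atari_lib.create_atari_environment(game_name='Pong')
agent = create_incoherent_agent(
    None, env, agent_name='incoherent_implicit_quantile')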
def _create_test_agent(self, allow_partial_reload=False):
  # This dummy network allows us to deterministically anticipate that
  # action 0 will be selected by an argmax.

  class MockDQNNetwork(nn.Module):
    """The Jax network used in tests."""
    num_actions: int
    inputs_preprocessed: bool = False

    @nn.compact
    def __call__(self, x):
      # This weights_initializer gives action 0 a higher weight, ensuring
      # that it gets picked by the argmax.
      def custom_init(key, shape, dtype=jnp.float32):
        del key
        to_pick_first_action = onp.zeros(shape, dtype)
        to_pick_first_action[:, 0] = 1
        return to_pick_first_action

      x = x.astype(jnp.float32)
      x = x.reshape((-1))  # flatten
      x = nn.Dense(features=self.num_actions,
                   kernel_init=custom_init,
                   bias_init=nn.initializers.ones)(x)
      return atari_lib.DQNNetworkType(x)

  agent = dqn_agent.JaxDQNAgent(
      network=MockDQNNetwork,
      observation_shape=self.observation_shape,
      observation_dtype=self.observation_dtype,
      stack_size=self.stack_size,
      num_actions=self.num_actions,
      min_replay_history=self.min_replay_history,
      epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
      update_period=self.update_period,
      target_update_period=self.target_update_period,
      epsilon_eval=0.0,  # No exploration during evaluation.
      allow_partial_reload=allow_partial_reload)
  # This ensures non-random action choices (since epsilon_eval = 0.0) and
  # skips the train_step.
  agent.eval_mode = True
  return agent
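# --- Illustrative sketch (not part of the original test file) ---
# A test built on this helper could check the deterministic action choice
# directly: with MockDQNNetwork and epsilon forced to 0, the argmax over
# Q-values always lands on action 0. The method below is hypothetical and
# assumes the surrounding TestCase's setUp defines observation_shape etc.
def testArgmaxSelectsActionZero(self):
  agent = self._create_test_agent()
  observation = onp.ones(self.observation_shape)
  self.assertEqual(agent.begin_episode(observation), 0)
  self.assertEqual(agent.step(reward=1, observation=observation), 0)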
JaxDQNAgent.update_period = 4
JaxDQNAgent.target_update_period = 100
JaxDQNAgent.epsilon_fn = @dqn_agent.identity_epsilon

create_optimizer.name = 'adam'
create_optimizer.learning_rate = 0.001
create_optimizer.eps = 3.125e-4

OutOfGraphReplayBuffer.replay_capacity = 50000
OutOfGraphReplayBuffer.batch_size = 128
"""
gin.parse_config(cartpole_config, skip_unknown=False)

dqn_agent = dqn_agent.JaxDQNAgent(num_actions=cartpole_env.action_space.n,
                                  observation_shape=(4, 1),
                                  observation_dtype=jnp.float64,
                                  stack_size=1,
                                  network=networks.CartpoleDQNNetwork)


def learned_policy(s):
  return dqn_agent.step(0., s)  # We pass in a dummy reward.

# We set our agent to `eval_mode` to prevent it from continuing to train while
# interacting with the environment.
dqn_agent.eval_mode = True

# animate_agent(learned_policy, cartpole_env, num_frames=100)

max_steps_per_episode = 200  # @param {type:'slider', min:10, max:1000}
training_steps = 1000  # @param {type:'slider', min:10, max:5000}
num_iterations = 30  # @param {type:'slider', min:10, max:200}
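# --- Training-loop sketch (not the colab's Runner) ---
# The three sliders above parameterize the training run; a minimal hand-rolled
# loop over them could look like the following. This is a sketch only: it
# assumes `cartpole_env` follows the classic gym reset()/step() API and reuses
# the `dqn_agent` instance created above.
dqn_agent.eval_mode = False  # Re-enable training updates for this sketch.
for _ in range(num_iterations):
  steps_this_iteration = 0
  while steps_this_iteration < training_steps:
    observation = cartpole_env.reset()
    action = dqn_agent.begin_episode(observation)
    for _ in range(max_steps_per_episode):
      observation, reward, done, _ = cartpole_env.step(action)
      steps_this_iteration += 1
      if done:
        dqn_agent.end_episode(reward)
        break
      action = dqn_agent.step(reward, observation)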