# Imports assumed by the constructors below (module aliases follow Acme's
# conventions; exact paths may differ between Acme versions).
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import reverb
import rlax

from acme import datasets
from acme import specs
from acme.adders import reverb as adders
from acme.agents import replay
from acme.agents.jax import actors
from acme.agents.jax.dqn import learning
from acme.agents.jax.dqn.config import DQNConfig
from acme.jax import networks as networks_lib
from acme.jax import variable_utils


def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: DQNConfig,
):
  """Initialize the agent."""
  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=config.n_step,
      batch_size=config.batch_size,
      max_replay_size=config.max_replay_size,
      min_replay_size=config.min_replay_size,
      priority_exponent=config.priority_exponent,
      discount=config.discount,
  )
  self._server = reverb_replay.server

  # Clip gradients by their global norm before applying Adam.
  optimizer = optax.chain(
      optax.clip_by_global_norm(config.max_gradient_norm),
      optax.adam(config.learning_rate),
  )
  # Derive independent random keys for the learner and the actor.
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      random_key=key_learner,
      optimizer=optimizer,
      discount=config.discount,
      importance_sampling_exponent=config.importance_sampling_exponent,
      target_update_period=config.target_update_period,
      iterator=reverb_replay.data_iterator,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to the policy.
  def policy(params: networks_lib.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(key_actor),
      variable_client=variable_utils.VariableClient(learner, ''),
      adder=reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(config.batch_size, config.min_replay_size),
      observations_per_step=config.batch_size / config.samples_per_insert,
  )
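
# --- Usage sketch (not part of the original file) ---------------------------
# A minimal example of constructing the config-based agent above. The agent
# class name `DQN`, the `environment` variable (a dm_env.Environment), and the
# MLP layer sizes are assumptions for illustration; `DQNConfig` and the Acme
# helpers are taken from the library as imported above.
from acme.jax import utils

spec = specs.make_environment_spec(environment)
num_actions = spec.actions.num_values

def forward(obs):
  # A small MLP Q-network; layer sizes are illustrative.
  return hk.nets.MLP([64, 64, num_actions])(obs)

forward = hk.without_apply_rng(hk.transform(forward))
dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
network = networks_lib.FeedForwardNetwork(
    init=lambda key: forward.init(key, dummy_obs),
    apply=forward.apply)

agent = DQN(spec, network, DQNConfig(epsilon=0.05))
# -----------------------------------------------------------------------------
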
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: hk.Transformed,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    seed: int = 1,
):
  """Initialize the agent."""
  # Create a replay server to add data to. This uses no limiter behavior in
  # order to allow the Agent interface to handle it.
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=adders.NStepTransitionAdder.signature(
          environment_spec=environment_spec))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay.
  address = f'localhost:{self._server.port}'
  adder = adders.NStepTransitionAdder(
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset provides an interface to sample from replay.
  dataset = datasets.make_reverb_dataset(
      server_address=address,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      transition_adder=True)

  # Epsilon-greedy policy over the network's action values.
  def policy(params: hk.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(epsilon).sample(key, action_values)

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      rng=hk.PRNGSequence(seed),
      optimizer=optax.adam(learning_rate),
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      target_update_period=target_update_period,
      iterator=dataset.as_numpy_iterator(),
      replay_client=reverb.Client(address),
  )

  variable_client = variable_utils.VariableClient(learner, '')

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(seed),
      variable_client=variable_client,
      adder=adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
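
# --- Worked example (not part of the original file) --------------------------
# How the two arguments passed to super().__init__ above throttle learning,
# using this constructor's default values.
batch_size = 256
samples_per_insert = 32.0
min_replay_size = 1000

# No learner steps happen until this many transitions have been observed.
min_observations = max(batch_size, min_replay_size)             # -> 1000
# After that, one learner step (sampling one batch of 256 transitions) is
# taken every `observations_per_step` environment steps: 256 / 32 = 8 env
# steps per learner step, i.e. 32 sampled transitions per inserted one.
observations_per_step = float(batch_size) / samples_per_insert  # -> 8.0
# -----------------------------------------------------------------------------
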
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: hk.Transformed,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 0.5,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.05,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    seed: int = 1,
):
  """Initialize the agent."""
  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=n_step,
      batch_size=batch_size,
      max_replay_size=max_replay_size,
      min_replay_size=min_replay_size,
      priority_exponent=priority_exponent,
      discount=discount,
  )
  self._server = reverb_replay.server

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      rng=hk.PRNGSequence(seed),
      optimizer=optax.adam(learning_rate),
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      target_update_period=target_update_period,
      iterator=reverb_replay.data_iterator,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to the policy.
  def policy(params: hk.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(epsilon).sample(key, action_values)

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(seed),
      variable_client=variable_utils.VariableClient(learner, ''),
      adder=reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
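
# --- Sketch (not part of the original file) ----------------------------------
# The quantity the prioritized n-step replay above stores for each transition
# is the discounted n-step return R_t = sum_{i=0}^{n-1} discount^i * r_{t+i}.
# Illustrative only: the real NStepTransitionAdder also tracks the accumulated
# discount and handles episode boundaries.
import numpy as np

def nstep_return(rewards, discount, n):
  # Discounted sum of the next n rewards r_t, ..., r_{t+n-1}.
  return sum(discount**i * r for i, r in enumerate(rewards[:n]))

assert np.isclose(
    nstep_return([1.0, 1.0, 1.0, 1.0, 1.0], discount=0.99, n=5),
    sum(0.99**i for i in range(5)))
# -----------------------------------------------------------------------------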