def __init__(self,
             network: hk.Transformed,
             obs_spec: specs.Array,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             rng: hk.PRNGSequence,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""
  # Prioritized double Q-learning: the importance-sampling exponent corrects
  # for the bias introduced by sampling from prioritized replay.
  loss_fn = losses.PrioritizedDoubleQLearning(
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      max_abs_reward=max_abs_reward,
      huber_loss_parameter=huber_loss_parameter,
  )
  super().__init__(
      network=network,
      obs_spec=obs_spec,
      loss_fn=loss_fn,
      optimizer=optimizer,
      data_iterator=iterator,
      target_update_period=target_update_period,
      rng=rng,
      replay_client=replay_client,
      counter=counter,
      logger=logger,
  )
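
# A minimal sketch of the Haiku-style inputs this learner expects: an
# `hk.Transformed` Q-network and an `hk.PRNGSequence`. The MLP widths, the
# 4-dimensional observation, and the 2-action output are illustrative
# assumptions, not values taken from this module.
import haiku as hk
import jax.numpy as jnp
import optax

def q_network(obs: jnp.ndarray) -> jnp.ndarray:
  # One Q-value per action (2 actions assumed for illustration).
  return hk.nets.MLP([64, 64, 2])(obs)

network = hk.without_apply_rng(hk.transform(q_network))
rng = hk.PRNGSequence(42)
params = network.init(next(rng), jnp.zeros((1, 4)))  # dummy batched observation
optimizer = optax.adam(1e-3)
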
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             random_key: networks_lib.PRNGKey,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""
  # Unlike the Haiku-based version above, the network is a plain init/apply
  # pair and randomness comes from a single JAX key, so no obs_spec is needed.
  loss_fn = losses.PrioritizedDoubleQLearning(
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      max_abs_reward=max_abs_reward,
      huber_loss_parameter=huber_loss_parameter,
  )
  super().__init__(
      network=network,
      loss_fn=loss_fn,
      optimizer=optimizer,
      data_iterator=iterator,
      target_update_period=target_update_period,
      random_key=random_key,
      replay_client=replay_client,
      counter=counter,
      logger=logger,
  )
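
# A `FeedForwardNetwork` bundles an init/apply pair; one common way to build it
# is to close a Haiku transform over a dummy observation. The shapes and sizes
# below are assumptions for illustration only.
import haiku as hk
import jax
import jax.numpy as jnp
from acme.jax import networks as networks_lib

def forward(obs: jnp.ndarray) -> jnp.ndarray:
  return hk.nets.MLP([64, 64, 2])(obs)

transformed = hk.without_apply_rng(hk.transform(forward))
dummy_obs = jnp.zeros((1, 4))  # batched placeholder observation
network = networks_lib.FeedForwardNetwork(
    init=lambda key: transformed.init(key, dummy_obs),
    apply=transformed.apply)
random_key = jax.random.PRNGKey(0)
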
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: dqn_config.DQNConfig,
):
  """Initializes the agent."""
  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=config.n_step,
      batch_size=config.batch_size,
      max_replay_size=config.max_replay_size,
      min_replay_size=config.min_replay_size,
      priority_exponent=config.priority_exponent,
      discount=config.discount,
  )
  self._server = reverb_replay.server

  # Clip gradients by global norm before applying Adam updates.
  optimizer = optax.chain(
      optax.clip_by_global_norm(config.max_gradient_norm),
      optax.adam(config.learning_rate),
  )
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

  # The learner updates the parameters (and initializes them).
  loss_fn = losses.PrioritizedDoubleQLearning(
      discount=config.discount,
      importance_sampling_exponent=config.importance_sampling_exponent,
  )
  learner = learning_lib.SGDLearner(
      network=network,
      loss_fn=loss_fn,
      data_iterator=reverb_replay.data_iterator,
      optimizer=optimizer,
      target_update_period=config.target_update_period,
      random_key=key_learner,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to an epsilon-greedy policy. A single
  # scalar epsilon is expected here, not a schedule of values.
  assert not isinstance(config.epsilon, Sequence)

  def policy(params: networks_lib.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
  variable_client = variable_utils.VariableClient(learner, '')
  actor = actors.GenericActor(
      actor_core, key_actor, variable_client, reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(config.batch_size, config.min_replay_size),
      observations_per_step=config.batch_size / config.samples_per_insert,
  )
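
# Stripped of the agent plumbing, the actor's policy is just epsilon-greedy
# sampling over Q-values via rlax; the Q-values and epsilon below are made-up
# placeholders.
import jax
import jax.numpy as jnp
import rlax

key = jax.random.PRNGKey(0)
action_values = jnp.array([0.1, 0.9, 0.3])  # Q-values for 3 actions
# With probability 1 - epsilon take the argmax, otherwise a uniform action.
action = rlax.epsilon_greedy(epsilon=0.05).sample(key, action_values)
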
# This revision extends the learner above with a `stochastic_network` flag
# (forwarded to the loss), a configurable replay table name, and
# `num_sgd_steps_per_step` to amortize overhead across several SGD updates
# per learner step.
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             random_key: networks_lib.PRNGKey,
             stochastic_network: bool = False,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: Optional[reverb.Client] = None,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             num_sgd_steps_per_step: int = 1):
  """Initializes the learner."""
  loss_fn = losses.PrioritizedDoubleQLearning(
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      max_abs_reward=max_abs_reward,
      huber_loss_parameter=huber_loss_parameter,
      stochastic_network=stochastic_network,
  )
  super().__init__(
      network=network,
      loss_fn=loss_fn,
      optimizer=optimizer,
      data_iterator=iterator,
      target_update_period=target_update_period,
      random_key=random_key,
      replay_client=replay_client,
      replay_table_name=replay_table_name,
      counter=counter,
      logger=logger,
      num_sgd_steps_per_step=num_sgd_steps_per_step,
  )
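
# For intuition about the per-transition quantity this loss optimizes, here is
# a standalone sketch pairing rlax's double Q-learning TD error with a Huber
# loss; all numbers are placeholders, and the prioritized importance weighting
# is omitted.
import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([1.0, 2.0, 0.5])         # online Q(s_t, .)
q_t_value = jnp.array([1.5, 0.3, 2.0])     # target-network Q(s_t+1, .)
q_t_selector = jnp.array([1.2, 0.8, 1.9])  # online Q(s_t+1, .) picks the action
a_tm1 = jnp.asarray(1)                     # action taken at s_t
r_t = jnp.asarray(0.5)
discount_t = jnp.asarray(0.99)

td_error = rlax.double_q_learning(
    q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector)
loss = rlax.huber_loss(td_error, 1.0)  # quadratic near zero, linear in the tails
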