Example 1
 def __init__(self,
              network: hk.Transformed,
              obs_spec: specs.Array,
              discount: float,
              importance_sampling_exponent: float,
              target_update_period: int,
              iterator: Iterator[reverb.ReplaySample],
              optimizer: optax.GradientTransformation,
              rng: hk.PRNGSequence,
              max_abs_reward: float = 1.,
              huber_loss_parameter: float = 1.,
              replay_client: Optional[reverb.Client] = None,
              counter: Optional[counting.Counter] = None,
              logger: Optional[loggers.Logger] = None):
     """Initializes the learner."""
     loss_fn = losses.PrioritizedDoubleQLearning(
         discount=discount,
         importance_sampling_exponent=importance_sampling_exponent,
         max_abs_reward=max_abs_reward,
         huber_loss_parameter=huber_loss_parameter,
     )
     super().__init__(
         network=network,
         obs_spec=obs_spec,
         loss_fn=loss_fn,
         optimizer=optimizer,
         data_iterator=iterator,
         target_update_period=target_update_period,
         rng=rng,
         replay_client=replay_client,
         counter=counter,
         logger=logger,
     )
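
A minimal sketch of how this constructor might be called. The class name DQNLearner, the MLP body, the replay_iterator placeholder, and all hyperparameter values below are illustrative assumptions, not taken from the example itself.

import haiku as hk
import numpy as np
import optax
from acme import specs

# Hypothetical Q-network body; the real agent defines its own network elsewhere.
def q_network(obs):
  return hk.nets.MLP([64, 64, 4])(obs)

network = hk.transform(q_network)   # an hk.Transformed, as the signature expects
obs_spec = specs.Array(shape=(8,), dtype=np.float32, name='observation')

learner = DQNLearner(               # assumed name for the class defined above
    network=network,
    obs_spec=obs_spec,
    discount=0.99,
    importance_sampling_exponent=0.2,
    target_update_period=100,
    iterator=replay_iterator,       # assumed: an Iterator[reverb.ReplaySample]
    optimizer=optax.adam(1e-3),
    rng=hk.PRNGSequence(42),
)
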
Example 2
 def __init__(self,
              network: networks_lib.FeedForwardNetwork,
              discount: float,
              importance_sampling_exponent: float,
              target_update_period: int,
              iterator: Iterator[reverb.ReplaySample],
              optimizer: optax.GradientTransformation,
              random_key: networks_lib.PRNGKey,
              max_abs_reward: float = 1.,
              huber_loss_parameter: float = 1.,
              replay_client: Optional[reverb.Client] = None,
              counter: Optional[counting.Counter] = None,
              logger: Optional[loggers.Logger] = None):
   """Initializes the learner."""
   loss_fn = losses.PrioritizedDoubleQLearning(
       discount=discount,
       importance_sampling_exponent=importance_sampling_exponent,
       max_abs_reward=max_abs_reward,
       huber_loss_parameter=huber_loss_parameter,
   )
   super().__init__(
       network=network,
       loss_fn=loss_fn,
       optimizer=optimizer,
       data_iterator=iterator,
       target_update_period=target_update_period,
       random_key=random_key,
       replay_client=replay_client,
       counter=counter,
       logger=logger,
   )
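
The same learner, but taking a networks_lib.FeedForwardNetwork and a raw JAX PRNG key instead of an hk.PRNGSequence. A sketch under the same assumptions as above (DQNLearner and replay_iterator are placeholders); wrapping a Haiku transform into a FeedForwardNetwork is one common way to build the network argument, not the only one.

import haiku as hk
import jax
import numpy as np
import optax
from acme.jax import networks as networks_lib

def q_network(obs):
  # Hypothetical Q-network body with 4 discrete actions.
  return hk.nets.MLP([64, 64, 4])(obs)

transformed = hk.without_apply_rng(hk.transform(q_network))
dummy_obs = np.zeros((1, 8), dtype=np.float32)
network = networks_lib.FeedForwardNetwork(
    init=lambda key: transformed.init(key, dummy_obs),
    apply=transformed.apply,
)

learner = DQNLearner(               # assumed name for the class defined above
    network=network,
    discount=0.99,
    importance_sampling_exponent=0.2,
    target_update_period=100,
    iterator=replay_iterator,       # assumed: an Iterator[reverb.ReplaySample]
    optimizer=optax.adam(1e-3),
    random_key=jax.random.PRNGKey(0),
)
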
Example 3
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: networks_lib.FeedForwardNetwork,
      config: dqn_config.DQNConfig,
  ):
    """Initialize the agent."""
    # Data is communicated via reverb replay.
    reverb_replay = replay.make_reverb_prioritized_nstep_replay(
        environment_spec=environment_spec,
        n_step=config.n_step,
        batch_size=config.batch_size,
        max_replay_size=config.max_replay_size,
        min_replay_size=config.min_replay_size,
        priority_exponent=config.priority_exponent,
        discount=config.discount,
    )
    self._server = reverb_replay.server

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.adam(config.learning_rate),
    )
    key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))
    # The learner updates the parameters (and initializes them).
    loss_fn = losses.PrioritizedDoubleQLearning(
        discount=config.discount,
        importance_sampling_exponent=config.importance_sampling_exponent,
    )
    learner = learning_lib.SGDLearner(
        network=network,
        loss_fn=loss_fn,
        data_iterator=reverb_replay.data_iterator,
        optimizer=optimizer,
        target_update_period=config.target_update_period,
        random_key=key_learner,
        replay_client=reverb_replay.client,
    )

    # The actor selects actions according to the policy.
    # Epsilon must be a single scalar here, not a sequence/schedule of values.
    assert not isinstance(config.epsilon, Sequence)
    def policy(params: networks_lib.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
      action_values = network.apply(params, observation)
      return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
    variable_client = variable_utils.VariableClient(learner, '')
    actor = actors.GenericActor(
        actor_core, key_actor, variable_client, reverb_replay.adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(config.batch_size, config.min_replay_size),
        observations_per_step=config.batch_size / config.samples_per_insert,
    )
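
A sketch of constructing and running the agent whose __init__ is shown above. DQN is an assumed name for that agent class, environment stands for an already-built dm_env.Environment, network is a FeedForwardNetwork built as in the previous sketch, and the DQNConfig fields not set here are assumed to have defaults.

import acme
from acme import specs
from acme.agents.jax.dqn import config as dqn_config

spec = specs.make_environment_spec(environment)   # environment: a dm_env.Environment
config = dqn_config.DQNConfig(learning_rate=1e-3, batch_size=64, seed=1)
agent = DQN(                        # assumed name for the agent class defined above
    environment_spec=spec,
    network=network,
    config=config,
)

# The agent then plugs into the standard actor/environment loop.
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)
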
Example 4
 def __init__(self,
              network: networks_lib.FeedForwardNetwork,
              discount: float,
              importance_sampling_exponent: float,
              target_update_period: int,
              iterator: Iterator[reverb.ReplaySample],
              optimizer: optax.GradientTransformation,
              random_key: networks_lib.PRNGKey,
              stochastic_network: bool = False,
              max_abs_reward: float = 1.,
              huber_loss_parameter: float = 1.,
              replay_client: Optional[reverb.Client] = None,
              replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
              counter: Optional[counting.Counter] = None,
              logger: Optional[loggers.Logger] = None,
              num_sgd_steps_per_step: int = 1):
   """Initializes the learner."""
   loss_fn = losses.PrioritizedDoubleQLearning(
       discount=discount,
       importance_sampling_exponent=importance_sampling_exponent,
       max_abs_reward=max_abs_reward,
       huber_loss_parameter=huber_loss_parameter,
       stochastic_network=stochastic_network,
   )
   super().__init__(
       network=network,
       loss_fn=loss_fn,
       optimizer=optimizer,
       data_iterator=iterator,
       target_update_period=target_update_period,
       random_key=random_key,
       replay_client=replay_client,
       replay_table_name=replay_table_name,
       counter=counter,
       logger=logger,
       num_sgd_steps_per_step=num_sgd_steps_per_step,
   )
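
This variant differs from Example 2 only in the stochastic_network flag (forwarded to the loss), an explicit replay_table_name (left at its default here), and num_sgd_steps_per_step. A sketch of passing the extra arguments, reusing the hypothetical network and replay_iterator placeholders from the earlier sketches:

learner = DQNLearner(               # assumed name for the class defined above
    network=network,
    discount=0.99,
    importance_sampling_exponent=0.2,
    target_update_period=100,
    iterator=replay_iterator,
    optimizer=optax.adam(1e-3),
    random_key=jax.random.PRNGKey(0),
    stochastic_network=False,       # forwarded to PrioritizedDoubleQLearning
    num_sgd_steps_per_step=4,       # number of SGD updates batched into each learner step
)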