Example #1: IQN-style agent constructor (implicit quantile network with sampled quantile fractions)
  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: IqnInputs,
      network: parts.Network,
      optimizer: optax.GradientTransformation,
      transition_accumulator: Any,
      replay: replay_lib.TransitionReplay,
      batch_size: int,
      exploration_epsilon: Callable[[int], float],
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      huber_param: float,
      tau_samples_policy: int,
      tau_samples_s_tm1: int,
      tau_samples_s_t: int,
      rng_key: parts.PRNGKey,
  ):
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key = jax.random.split(rng_key)
    self._online_params = network.init(
        network_rng_key,
        jax.tree_map(lambda x: x[None, ...], sample_network_input))
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def loss_fn(online_params, target_params, transitions, rng_key):
      """Calculates loss given network parameters and transitions."""
      # Sample tau values for q_tm1, q_t_selector, q_t.
      batch_size = self._batch_size
      rng_key, *sample_keys = jax.random.split(rng_key, 4)
      tau_tm1 = _sample_tau(sample_keys[0], (batch_size, tau_samples_s_tm1))
      tau_t_selector = _sample_tau(sample_keys[1],
                                   (batch_size, tau_samples_policy))
      tau_t = _sample_tau(sample_keys[2], (batch_size, tau_samples_s_t))

      # Compute Q value distributions.
      _, *apply_keys = jax.random.split(rng_key, 4)
      dist_q_tm1 = network.apply(online_params, apply_keys[0],
                                 IqnInputs(transitions.s_tm1, tau_tm1)).q_dist
      dist_q_t_selector = network.apply(
          target_params, apply_keys[1],
          IqnInputs(transitions.s_t, tau_t_selector)).q_dist
      dist_q_target_t = network.apply(target_params, apply_keys[2],
                                      IqnInputs(transitions.s_t, tau_t)).q_dist
      losses = _batch_quantile_q_learning(
          dist_q_tm1,
          tau_tm1,
          transitions.a_tm1,
          transitions.r_t,
          transitions.discount_t,
          dist_q_t_selector,
          dist_q_target_t,
          huber_param,
      )
      assert losses.shape == (self._batch_size,)
      loss = jnp.mean(losses)
      return loss

    def update(rng_key, opt_state, online_params, target_params, transitions):
      """Computes learning update from batch of replay transitions."""
      rng_key, update_key = jax.random.split(rng_key)
      d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                          transitions, update_key)
      updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return rng_key, new_opt_state, new_online_params

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, sample_key, apply_key, policy_key = jax.random.split(rng_key, 4)
      tau_t = _sample_tau(sample_key, (1, tau_samples_policy))
      q_t = network.apply(network_params, apply_key,
                          IqnInputs(s_t[None, ...], tau_t)).q_values[0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      return rng_key, a_t

    self._select_action = jax.jit(select_action)
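
The IQN loss and policy above rely on a _sample_tau helper that is not part of this excerpt. A minimal sketch, assuming the helper simply draws quantile fractions uniformly from [0, 1) for a requested shape:

import jax


def _sample_tau(rng_key, shape):
  # Assumed helper (not shown in the excerpt): IQN draws quantile fractions
  # tau uniformly at random; the same helper serves the online, selector, and
  # target distributions in loss_fn and the policy in select_action.
  return jax.random.uniform(rng_key, shape=shape)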
Example #2: multi-head (bootstrapped-style) DQN agent constructor with a reward-shaping function
    def __init__(
        self,
        preprocessor: processors.Processor,
        sample_network_input: jnp.ndarray,
        network: parts.Network,
        optimizer: optax.GradientTransformation,
        transition_accumulator: Any,
        replay: replay_lib.TransitionReplay,
        shaping_function,
        mask_probability: float,
        num_heads: int,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        grad_error_bound: float,
        rng_key: parts.PRNGKey,
    ):
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._mask_probabilities = jnp.array(
            [mask_probability, 1 - mask_probability])
        self._num_heads = num_heads
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key, shaping_key = jax.random.split(
                rng_key, 4)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).multi_head_output
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).multi_head_output

            # Broadcast mask: [batch, num_heads] -> [batch, num_heads, num_actions].
            mask = jnp.einsum('ij,k->ijk', transitions.mask_t,
                              jnp.ones(q_tm1.shape[-1]))

            masked_q = jnp.multiply(mask, q_tm1)
            masked_q_target = jnp.multiply(mask, q_target_t)

            flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
            flattened_q_target = jnp.reshape(q_target_t,
                                             (-1, q_target_t.shape[-1]))

            # Compute the shaping function F(s, a, s').
            shaped_rewards = shaping_function(q_target_t, transitions,
                                              shaping_key)

            repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
            repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
            repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

            td_errors = _batch_q_learning(
                flattened_q,
                repeated_actions,
                repeated_rewards,
                repeated_discounts,
                flattened_q_target,
            )

            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size * num_heads, )
            loss = jnp.mean(losses)
            return loss

        def update(rng_key, opt_state, online_params, target_params,
                   transitions):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                                transitions, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)

            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).random_head_q_value[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            return rng_key, a_t

        self._select_action = jax.jit(select_action)
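
The constructor above stores self._mask_probabilities, but the code that actually draws the per-head bootstrap masks (transitions.mask_t) lies outside this excerpt. One plausible sketch, assuming a hypothetical sample_bootstrap_mask helper in which index 0 of mask_probabilities is the probability of dropping a head and index 1 the probability of keeping it:

import jax
import jax.numpy as jnp


def sample_bootstrap_mask(rng_key, num_heads, mask_probabilities):
    # Hypothetical helper (not in the excerpt): draws an independent 0/1
    # keep-mask for each head, so each head trains on a random subset of
    # transitions.
    return jax.random.choice(
        rng_key, jnp.array([0.0, 1.0]), shape=(num_heads,),
        p=mask_probabilities)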
Example #3: categorical (C51-style) DQN agent constructor with a fixed return support
    def __init__(
        self,
        preprocessor: processors.Processor,
        sample_network_input: jnp.ndarray,
        network: parts.Network,
        support: jnp.ndarray,
        optimizer: optax.GradientTransformation,
        transition_accumulator: Any,
        replay: replay_lib.TransitionReplay,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        rng_key: parts.PRNGKey,
    ):
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.
        self._statistics = {'state_value': np.nan}

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key = jax.random.split(rng_key, 3)
            logits_q_tm1 = network.apply(online_params, online_key,
                                         transitions.s_tm1).q_logits
            logits_target_q_t = network.apply(target_params, target_key,
                                              transitions.s_t).q_logits
            losses = _batch_categorical_q_learning(
                support,
                logits_q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                support,
                logits_target_q_t,
            )
            chex.assert_shape(losses, (self._batch_size, ))
            loss = jnp.mean(losses)
            return loss

        def update(rng_key, opt_state, online_params, target_params,
                   transitions):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                                transitions, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).q_values[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            v_t = jnp.max(q_t, axis=-1)
            return rng_key, a_t, v_t

        self._select_action = jax.jit(select_action)
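
The support argument is the fixed grid of return atoms consumed by the categorical loss. A minimal construction sketch, using hypothetical num_atoms and max_abs_return values (the excerpt does not fix them):

import jax.numpy as jnp

# Hypothetical values; the excerpt does not specify the atom count or range.
num_atoms = 51
max_abs_return = 10.0
support = jnp.linspace(-max_abs_return, max_abs_return, num_atoms)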
Example #4: double DQN agent constructor with prioritized experience replay
    def __init__(
        self,
        preprocessor: processors.Processor,
        sample_network_input: jnp.ndarray,
        network: parts.Network,
        optimizer: optax.GradientTransformation,
        transition_accumulator: replay_lib.TransitionAccumulator,
        replay: replay_lib.PrioritizedTransitionReplay,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        grad_error_bound: float,
        rng_key: parts.PRNGKey,
    ):
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.
        self._max_seen_priority = 1.

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, weights,
                    rng_key):
            """Calculates loss given network parameters and transitions."""
            _, *apply_keys = jax.random.split(rng_key, 4)
            q_tm1 = network.apply(online_params, apply_keys[0],
                                  transitions.s_tm1).q_values
            q_t = network.apply(online_params, apply_keys[1],
                                transitions.s_t).q_values
            q_target_t = network.apply(target_params, apply_keys[2],
                                       transitions.s_t).q_values
            td_errors = _batch_double_q_learning(
                q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                q_target_t,
                q_t,
            )
            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size, ) == weights.shape
            # This is not the same as using a huber loss and multiplying by weights.
            loss = jnp.mean(losses * weights)
            return loss, td_errors

        def update(rng_key, opt_state, online_params, target_params,
                   transitions, weights):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params, td_errors = jax.grad(loss_fn, has_aux=True)(
                online_params, target_params, transitions, weights, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params, td_errors

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).q_values[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            return rng_key, a_t

        self._select_action = jax.jit(select_action)
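
The jitted update above returns td_errors alongside the new parameters, and the constructor tracks self._max_seen_priority, which suggests the TD errors are written back to the prioritized buffer as fresh priorities. A sketch under that assumption, using a hypothetical priority_exponent and assuming the replay object exposes an update_priorities(indices, priorities) method:

import numpy as np


def update_replay_priorities(replay, indices, td_errors,
                             priority_exponent=0.6, max_seen_priority=1.0):
    # Sketch only: proportional prioritization, priority = |TD error|^alpha,
    # while tracking the largest priority seen so far for new transitions.
    priorities = np.abs(np.asarray(td_errors)) ** priority_exponent
    max_seen_priority = max(max_seen_priority, float(np.max(priorities)))
    replay.update_priorities(indices, priorities)
    return max_seen_priority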