Example #1
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: PolicyValueNet,
        optimizer: optix.InitUpdate,
        rng: hk.PRNGSequence,
        sequence_length: int,
        discount: float,
        td_lambda: float,
    ):
        """Builds the actor-critic learner: loss, jitted SGD step and state.

        Args:
          obs_spec: Observation spec. Its shape sizes the dummy input used to
            initialise parameters; it is also passed to the trajectory buffer.
          action_spec: Action spec, passed to the trajectory buffer.
          network: Function returning `(logits, values)` for a batch of
            observations; it is wrapped with `hk.transform` here.
          optimizer: Optix optimiser providing `init` and `update`.
          rng: Haiku PRNG sequence. One key is drawn for parameter init and
            the sequence is stored on the agent.
          sequence_length: Trajectory length for `sequence.Buffer`.
          discount: Scalar multiplied into the trajectory's per-step discounts.
          td_lambda: Lambda for the TD(lambda) errors used by both losses.
        """

        # Define loss function.
        def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
            """Actor-critic loss: squared TD(lambda) error + policy gradient."""
            logits, values = network(trajectory.observations)
            # values[:-1] / values[1:] pair each step's V(s_tm1) with V(s_t),
            # aligned with the per-step rewards and discounts.
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=trajectory.rewards,
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda),
            )
            critic_loss = jnp.mean(td_errors**2)
            # The TD errors double as advantages for the actor.
            # NOTE(review): relies on rlax.policy_gradient_loss stopping
            # gradients through `adv_t`; confirm for the pinned rlax version.
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=trajectory.actions,
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            return actor_loss + critic_loss

        # Transform the loss into a pure function of (params, trajectory).
        loss_fn = hk.transform(loss).apply

        # Define update function.
        @jax.jit
        def sgd_step(state: TrainingState,
                     trajectory: sequence.Trajectory) -> TrainingState:
            """Does a step of SGD over a trajectory."""
            gradients = jax.grad(loss_fn)(state.params, trajectory)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)
            return TrainingState(params=new_params, opt_state=new_opt_state)

        # Initialize network parameters and optimiser state.
        init, forward = hk.transform(network)
        # A batch of one all-zero observation, used only for initialisation.
        dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
        initial_params = init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        # Internalize state.
        self._state = TrainingState(initial_params, initial_opt_state)
        self._forward = jax.jit(forward)
        self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
        self._sgd_step = sgd_step
        self._rng = rng
Example #2
0
    def __init__(self,
                 network: networks.QNetwork,
                 obs_spec: specs.Array,
                 discount: float,
                 importance_sampling_exponent: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optix.InitUpdate,
                 rng: hk.PRNGSequence,
                 max_abs_reward: float = 1.,
                 huber_loss_parameter: float = 1.,
                 replay_client: reverb.Client = None,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None):
        """Initializes the learner.

        Args:
          network: Q-network function; wrapped with `hk.transform` here.
          obs_spec: Observation spec, used to build a dummy batched input for
            parameter initialisation.
          discount: Multiplier applied to the sampled discounts.
          importance_sampling_exponent: Exponent applied to `1 / probs` when
            computing importance weights for prioritized replay.
          target_update_period: Stored as `self._target_update_period`;
            presumably the number of steps between target-network refreshes.
          iterator: Source of replay samples; wrapped with `utils.prefetch`.
          optimizer: Optix optimiser providing `init` and `update`.
          rng: Haiku PRNG sequence; two keys are drawn here (online and target
            parameter init).
          max_abs_reward: Rewards are clipped to
            [-max_abs_reward, max_abs_reward].
          huber_loss_parameter: Threshold passed to `rlax.huber_loss`.
          replay_client: Reverb client used to write back new priorities.
          counter: Step counter; a fresh `counting.Counter()` if not given.
          logger: Logger; a `TerminalLogger('learner', time_delta=1.)` if not
            given.
        """

        # Transform network into a pure function.
        network = hk.transform(network)

        def loss(params: hk.Params, target_params: hk.Params,
                 sample: reverb.ReplaySample):
            """Importance-weighted double Q-learning loss.

            Returns:
              `(mean_loss, (keys, priorities))`: the scalar loss to minimise,
              plus the sampled item keys and their new priorities |td_error|.
            """
            o_tm1, a_tm1, r_t, d_t, o_t = sample.data
            keys, probs = sample.info[:2]

            # Forward pass: the online net scores o_tm1 and selects actions at
            # o_t; the target net evaluates o_t.
            q_tm1 = network.apply(params, o_tm1)
            q_t_value = network.apply(target_params, o_t)
            q_t_selector = network.apply(params, o_t)

            # Cast and clip rewards.
            d_t = (d_t * discount).astype(jnp.float32)
            r_t = jnp.clip(r_t, -max_abs_reward,
                           max_abs_reward).astype(jnp.float32)

            # Compute double Q-learning n-step TD-error.
            batch_error = jax.vmap(rlax.double_q_learning)
            td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                   q_t_selector)
            batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

            # Importance weighting: (1 / probs) ** exponent, normalised so the
            # largest weight in the batch is 1.
            importance_weights = (1. / probs).astype(jnp.float32)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)

            # Reweight.
            mean_loss = jnp.mean(importance_weights * batch_loss)  # []

            # NOTE(review): unless JAX x64 mode is enabled, this float64 cast
            # actually yields float32.
            priorities = jnp.abs(td_error).astype(jnp.float64)

            return mean_loss, (keys, priorities)

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, LearnerOutputs]:
            """Applies one optimiser update; returns new state and priorities."""
            grad_fn = jax.grad(loss, has_aux=True)
            gradients, (keys, priorities) = grad_fn(state.params,
                                                    state.target_params,
                                                    samples)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            # Target params are carried through unchanged here.
            new_state = TrainingState(params=new_params,
                                      target_params=state.target_params,
                                      opt_state=new_opt_state,
                                      step=state.step + 1)

            outputs = LearnerOutputs(keys=keys, priorities=priorities)

            return new_state, outputs

        def update_priorities(outputs: LearnerOutputs):
            """Writes the new priorities of the sampled items back to replay."""
            # Copy to host first: JAX device arrays are unhashable and cannot
            # be dict keys. This is a no-op for data that is already NumPy.
            keys, priorities = jax.device_get(
                (outputs.keys, outputs.priorities))
            # One batched RPC instead of one per sampled item. If a key was
            # sampled twice, the last priority wins, as with sequential calls.
            replay_client.mutate_priorities(
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Internalise agent components (replay buffer, networks, optimizer).
        self._replay_client = replay_client
        self._iterator = utils.prefetch(iterator)

        # Internalise the hyperparameters.
        self._target_update_period = target_update_period

        # Internalise logging/counting objects.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Initialise parameters and optimiser state from a batch of one
        # all-zero observation.
        dummy_obs = utils.add_batch_dim(utils.zeros_like(obs_spec))
        initial_params = network.init(next(rng), dummy_obs)
        # NOTE(review): target params come from an independent init rather
        # than a copy of `initial_params`; confirm this is intended.
        initial_target_params = network.init(next(rng), dummy_obs)
        initial_opt_state = optimizer.init(initial_params)

        self._state = TrainingState(params=initial_params,
                                    target_params=initial_target_params,
                                    opt_state=initial_opt_state,
                                    step=0)

        self._forward = jax.jit(network.apply)
        self._sgd_step = jax.jit(sgd_step)
        # Priority write-back runs through an async executor, presumably so
        # the replay RPC does not block the learner step.
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
Example #3
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: hk.Transformed,
        num_ensemble: int,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: optix.InitUpdate,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = 1,
    ):
        """Bootstrapped DQN with randomized prior functions.

        Builds `num_ensemble` members that share one transformed network and
        one jitted SGD step. Each member has its own online params, target
        params and optimiser state.

        Args:
          obs_spec: Observation spec; its shape sizes the dummy init input.
          action_spec: Action spec; `num_values` is stored as the action count.
          network: Transformed Q-network shared by every ensemble member.
          num_ensemble: Number of ensemble members.
          batch_size: Stored for use outside this constructor.
          discount: Multiplier applied to the sampled discounts `d_t`.
          replay_capacity: Capacity of the `replay.Replay` buffer.
          min_replay_size: Stored for use outside this constructor.
          sgd_period: Stored for use outside this constructor.
          target_update_period: Stored for use outside this constructor.
          optimizer: Optix optimiser; one optimiser state per member.
          mask_prob: Stored; presumably the probability used to draw the
            bootstrap masks `m_t`.
          noise_scale: Scale of the reward noise `z_t` added in the loss.
          epsilon_fn: Stored; defaults to a function that always returns 0.
          seed: Seed for the `hk.PRNGSequence` used for all parameter inits.
        """

        # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`.
        def loss(params: hk.Params, target_params: hk.Params,
                 transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Q-learning loss with additive reward noise and a bootstrap mask."""
            o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
            q_tm1 = network.apply(params, o_tm1)
            q_t = network.apply(target_params, o_t)
            # Perturb rewards with scaled noise (rebinds; inputs not mutated).
            r_t += noise_scale * z_t
            batch_q_learning = jax.vmap(rlax.q_learning)
            td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
            # The mask weights each transition's contribution for this member.
            return jnp.mean(m_t * td_error**2)

        # Define the update function. It is shared by every ensemble member and
        # operates on one member's TrainingState at a time.
        @jax.jit
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Does one SGD step for a single ensemble member over `transitions`."""

            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)

        # Initialize parameters and optimizer state for an ensemble of Q-networks.
        # Every init draws a fresh key, so members (and their targets) differ.
        rng = hk.PRNGSequence(seed)
        dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
        initial_params = [
            network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
        ]
        initial_target_params = [
            network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
        ]
        initial_opt_state = [optimizer.init(p) for p in initial_params]

        # Internalize state.
        self._ensemble = [
            TrainingState(p, tp, o, step=0) for p, tp, o in zip(
                initial_params, initial_target_params, initial_opt_state)
        ]
        self._forward = jax.jit(network.apply)
        self._sgd_step = sgd_step
        self._num_ensemble = num_ensemble
        self._optimizer = optimizer
        self._replay = replay.Replay(capacity=replay_capacity)

        # Agent hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._mask_prob = mask_prob

        # Agent state: start with the first member as the active head.
        self._active_head = self._ensemble[0]
        self._total_steps = 0
Example #4
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: hk.Transformed,
        optimizer: optix.InitUpdate,
        batch_size: int,
        epsilon: float,
        rng: hk.PRNGSequence,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
    ):
        """Creates a DQN agent: networks, optimiser, replay and settings.

        Args:
          obs_spec: Observation spec; its shape sizes the dummy init input.
          action_spec: Action spec; `num_values` is kept as the action count.
          network: Transformed Q-network used for online and target params.
          optimizer: Optix optimiser providing `init` and `update`.
          batch_size: Stored for use outside this constructor.
          epsilon: Stored for use outside this constructor.
          rng: Haiku PRNG sequence; two keys are drawn (online, then target).
          discount: Multiplier applied to the sampled discounts.
          replay_capacity: Capacity of the `replay.Replay` buffer.
          min_replay_size: Stored for use outside this constructor.
          sgd_period: Stored for use outside this constructor.
          target_update_period: Stored for use outside this constructor.
        """

        def loss(params: hk.Params, target_params: hk.Params,
                 transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Mean squared TD(0) Q-learning error over a batch of transitions."""
            obs_prev, act_prev, reward, disc, obs_next = transitions
            q_online = network.apply(params, obs_prev)
            q_target = network.apply(target_params, obs_next)
            errors = jax.vmap(rlax.q_learning)(
                q_online, act_prev, reward, discount * disc, q_target)
            return jnp.mean(errors**2)

        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Applies one optimiser update computed from `transitions`."""
            grads = jax.grad(loss)(state.params, state.target_params,
                                   transitions)
            param_updates, opt_state = optimizer.update(grads, state.opt_state)
            return TrainingState(
                params=optix.apply_updates(state.params, param_updates),
                target_params=state.target_params,
                opt_state=opt_state,
                step=state.step + 1)

        # Online params first, then target params: each draws its own key.
        sample_input = np.zeros((1, *obs_spec.shape), jnp.float32)
        online_params = network.init(next(rng), sample_input)
        target_params = network.init(next(rng), sample_input)

        # Training-relevant agent state.
        self._state = TrainingState(
            params=online_params,
            target_params=target_params,
            opt_state=optimizer.init(online_params),
            step=0)
        self._sgd_step = jax.jit(sgd_step)
        self._forward = jax.jit(network.apply)
        self._replay = replay.Replay(capacity=replay_capacity)

        # Hyperparameters and counters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._epsilon = epsilon
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._total_steps = 0
Example #5
0
    def __init__(
        self,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        obs_spec: specs.Array,
        iterator: Iterator[reverb.ReplaySample],
        optimizer: optix.InitUpdate,
        rng: hk.PRNGSequence,
        discount: float = 0.99,
        entropy_cost: float = 0.,
        baseline_cost: float = 1.,
        max_abs_reward: float = np.inf,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
    ):
        """Builds a V-trace actor-critic learner over replayed sequences.

        Args:
          network: Recurrent policy/value function; wrapped with
            `hk.transform` here and unrolled with `hk.static_unroll`.
          initial_state_fn: Returns the initial recurrent state; used here
            only to initialise parameters.
          obs_spec: Observation spec; a single unbatched zero observation
            built from it is used for parameter initialisation.
          iterator: Source of replay samples; wrapped with
            `jax_utils.prefetch`.
          optimizer: Optix optimiser providing `init` and `update`.
          rng: Haiku PRNG sequence; one key is drawn for parameter init.
          discount: Multiplier applied to the sampled discounts.
          entropy_cost: Weight of the entropy regulariser.
          baseline_cost: Weight of the critic (V-trace error) loss.
          max_abs_reward: Rewards are clipped to
            [-max_abs_reward, max_abs_reward]; the default `np.inf` disables
            clipping.
          counter: Step counter; a fresh `counting.Counter()` if not given.
          logger: Logger; a `TerminalLogger('learner', time_delta=1.)` if not
            given.
        """

        # Initialise training state (parameters & optimiser state).
        network = hk.transform(network)
        initial_network_state = hk.transform(initial_state_fn).apply(None)
        initial_params = network.init(next(rng),
                                      jax_utils.zeros_like(obs_spec),
                                      initial_network_state)
        initial_opt_state = optimizer.init(initial_params)

        def loss(params: hk.Params, sample: reverb.ReplaySample):
            """V-trace loss for a single (unbatched) sequence."""

            # Extract the data.
            observations, actions, rewards, discounts, extra = sample.data
            # Start the unroll from the recurrent state stored at step 0.
            initial_state = tree.map_structure(lambda s: s[0],
                                               extra['core_state'])
            behaviour_logits = extra['logits']

            # Keep the first T-1 steps: the final observation contributes only
            # its bootstrap value (through values[1:] below). Then clip rewards.
            actions = actions[:-1]  # [T-1]
            rewards = rewards[:-1]  # [T-1]
            discounts = discounts[:-1]  # [T-1]
            rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

            # Unroll current policy over observations.
            net = functools.partial(network.apply, params)
            (logits, values), _ = hk.static_unroll(net, observations,
                                                   initial_state)

            # Compute importance sampling weights: current policy / behavior policy.
            rhos = rlax.categorical_importance_sampling_ratios(
                logits[:-1], behaviour_logits[:-1], actions)

            # Critic loss (per-step squared V-trace errors).
            vtrace_returns = rlax.vtrace_td_error_and_advantage(
                v_tm1=values[:-1],
                v_t=values[1:],
                r_t=rewards,
                discount_t=discounts * discount,
                rho_t=rhos)
            critic_loss = jnp.square(vtrace_returns.errors)

            # Policy gradient loss.
            policy_gradient_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=actions,
                adv_t=vtrace_returns.pg_advantage,
                w_t=jnp.ones_like(rewards))

            # Entropy regulariser.
            entropy_loss = rlax.entropy_loss(logits[:-1],
                                             jnp.ones_like(rewards))

            # Combine weighted sum of actor & critic losses.
            mean_loss = jnp.mean(policy_gradient_loss +
                                 baseline_cost * critic_loss +
                                 entropy_cost * entropy_loss)

            return mean_loss

        @jax.jit
        def sgd_step(state: TrainingState, sample: reverb.ReplaySample):
            """Applies one optimiser update; returns new state and metrics."""
            # vmap the per-sequence loss over the batch axis of `sample` (params
            # shared), average, and differentiate. No gradient clipping happens
            # here; presumably any clipping is configured in `optimizer`.
            batch_loss = jax.vmap(loss, in_axes=(None, 0))
            mean_loss = lambda p, s: jnp.mean(batch_loss(p, s))
            grad_fn = jax.value_and_grad(mean_loss)
            loss_value, gradients = grad_fn(state.params, sample)

            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            metrics = {
                'loss': loss_value,
            }

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state)

            return new_state, metrics

        self._state = TrainingState(params=initial_params,
                                    opt_state=initial_opt_state)

        # Internalise iterator.
        self._iterator = jax_utils.prefetch(iterator)
        self._sgd_step = sgd_step

        # Set up logging/counting.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)
Example #6
0
  def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: RecurrentPolicyValueNet,
      initial_rnn_state: LSTMState,
      optimizer: optix.InitUpdate,
      rng: hk.PRNGSequence,
      sequence_length: int,
      discount: float,
      td_lambda: float,
      entropy_cost: float = 0.,
  ):
    """Builds the recurrent actor-critic learner.

    Args:
      obs_spec: Observation spec. Its shape and dtype define the dummy init
        input; it is also passed to the trajectory buffer.
      action_spec: Action spec, passed to the trajectory buffer.
      network: Recurrent network returning `(logits, values)` plus a new
        recurrent state; unrolled over time with `hk.dynamic_unroll`.
      initial_rnn_state: Recurrent state used for init and stored on the agent.
      optimizer: Optix optimiser providing `init` and `update`.
      rng: Haiku PRNG sequence. One key is drawn for parameter init and the
        sequence is stored on the agent.
      sequence_length: Trajectory length for `sequence.Buffer`.
      discount: Scalar multiplied into the trajectory's per-step discounts.
      td_lambda: Lambda for the TD(lambda) errors used by both losses.
      entropy_cost: Weight of the entropy regulariser.
    """

    # Define loss function.
    def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
      """Actor-critic loss; also returns the RNN state after the unroll."""
      # Insert a batch axis of size 1 ([T, ...] -> [T, 1, ...]) and unroll
      # the network over time starting from `rnn_unroll_state`.
      (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
          network, trajectory.observations[:, None, ...], rnn_unroll_state)
      seq_len = trajectory.actions.shape[0]
      # values[:-1, 0] / values[1:, 0] pair each step's V(s_tm1) with V(s_t).
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1, 0],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:, 0],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      # NOTE(review): relies on rlax.policy_gradient_loss stopping gradients
      # through `adv_t`; confirm for the pinned rlax version.
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1, 0],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones(seq_len))
      entropy_loss = jnp.mean(
          rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

      combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss

      return combined_loss, new_rnn_unroll_state

    # Transform the loss into a pure function of
    # (params, trajectory, rnn_unroll_state) that needs no apply-time RNG.
    loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply

    # Define update function.
    @jax.jit
    def sgd_step(state: AgentState,
                 trajectory: sequence.Trajectory) -> AgentState:
      """Does a step of SGD over a trajectory and carries the RNN state on."""
      gradients, new_rnn_state = jax.grad(
          loss_fn, has_aux=True)(state.params, trajectory,
                                 state.rnn_unroll_state)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)
      return state._replace(
          params=new_params,
          opt_state=new_opt_state,
          rnn_unroll_state=new_rnn_state)

    # Initialize network parameters and optimiser state.
    init, forward = hk.without_apply_rng(hk.transform(network, apply_rng=True))
    # A batch of one all-zero observation, used only for initialisation.
    dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=obs_spec.dtype)
    initial_params = init(next(rng), dummy_observation, initial_rnn_state)
    initial_opt_state = optimizer.init(initial_params)

    # Internalize state. `initial_rnn_state` fills both recurrent-state
    # slots of AgentState (presumably the acting and unroll states).
    self._state = AgentState(initial_params, initial_opt_state,
                             initial_rnn_state, initial_rnn_state)
    self._forward = jax.jit(forward)
    self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
    self._sgd_step = sgd_step
    self._rng = rng
    self._initial_rnn_state = initial_rnn_state