Example #1
def rnd_update_step(
    state: RNDTrainingState, transitions: types.Transition, loss_fn: RNDLoss,
    optimizer: optax.GradientTransformation
) -> Tuple[RNDTrainingState, Dict[str, jnp.ndarray]]:
    """Run an update steps on the given transitions.

  Args:
    state: The learner state.
    transitions: Transitions to update on.
    loss_fn: The loss function.
    optimizer: The optimizer of the predictor network.

  Returns:
    A new state and metrics.
  """
    loss, grads = jax.value_and_grad(loss_fn)(state.params,
                                              state.target_params,
                                              transitions=transitions)

    update, optimizer_state = optimizer.update(grads, state.optimizer_state)
    params = optax.apply_updates(state.params, update)

    new_state = RNDTrainingState(
        optimizer_state=optimizer_state,
        params=params,
        target_params=state.target_params,
        steps=state.steps + 1,
    )
    return new_state, {'rnd_loss': loss}
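The function above is an instance of the standard optax update recipe repeated throughout these examples: differentiate a loss with jax.value_and_grad (or jax.grad), feed the gradients to optimizer.update, then apply the result with optax.apply_updates. A minimal, self-contained sketch of that recipe with a hypothetical quadratic loss and plain-dict parameters (none of these names come from the example above):

import jax
import jax.numpy as jnp
import optax

# Hypothetical toy problem: fit w and b so that w * x + b matches y.
def loss_fn(params, x, y):
    pred = params['w'] * x + params['b']
    return jnp.mean((pred - y) ** 2)

optimizer = optax.sgd(learning_rate=0.1)
params = {'w': jnp.zeros(()), 'b': jnp.zeros(())}
opt_state = optimizer.init(params)

@jax.jit
def update_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

x = jnp.array([0., 1., 2., 3.])
y = 2. * x + 1.
for _ in range(200):
    params, opt_state, loss = update_step(params, opt_state, x, y)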
Example #2
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: PolicyValueNet,
        optimizer: optax.GradientTransformation,
        rng: hk.PRNGSequence,
        sequence_length: int,
        discount: float,
        td_lambda: float,
    ):

        # Define loss function.
        def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            logits, values = network(trajectory.observations)
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=trajectory.rewards,
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda),
            )
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=trajectory.actions,
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            return actor_loss + critic_loss

        # Transform the loss into a pure function.
        loss_fn = hk.without_apply_rng(hk.transform(loss,
                                                    apply_rng=True)).apply

        # Define update function.
        @jax.jit
        def sgd_step(state: TrainingState,
                     trajectory: sequence.Trajectory) -> TrainingState:
            """Does a step of SGD over a trajectory."""
            gradients = jax.grad(loss_fn)(state.params, trajectory)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)
            return TrainingState(params=new_params, opt_state=new_opt_state)

        # Initialize network parameters and optimiser state.
        init, forward = hk.without_apply_rng(
            hk.transform(network, apply_rng=True))
        dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
        initial_params = init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        # Internalize state.
        self._state = TrainingState(initial_params, initial_opt_state)
        self._forward = jax.jit(forward)
        self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
        self._sgd_step = sgd_step
        self._rng = rng
Example #3
def clip_by_global_norm(max_norm) -> GradientTransformation:
    """Clip updates using their global norm.

    References:
      [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

    Args:
      max_norm: the maximum global norm for an update.

    Returns:
      An (init_fn, update_fn) tuple.
    """

    def init_fn(_):
        return ClipByGlobalNormState()

    def update_fn(updates, state, params=None):
        del params
        g_norm = global_norm(updates)
        trigger = g_norm < max_norm
        updates = jax.tree_map(
            lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
        return updates, state

    return GradientTransformation(init_fn, update_fn)
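A transformation like this is typically composed with a base optimizer via optax.chain, so gradients are clipped before the optimizer step is computed. A short sketch of that composition; it uses the built-in optax.clip_by_global_norm, which implements the same idea as the function above (the custom version could be substituted in, assuming ClipByGlobalNormState and global_norm are defined as in its source module):

import optax

# Clip gradients to a global norm of 1.0, then apply an Adam update.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=3e-4),
)
# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state)
# params = optax.apply_updates(params, updates)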
Example #4
def sgld_gradient_update(step_size_fn,
                         seed,
                         momentum_decay=0.,
                         preconditioner=None):
    """Optax implementation of the SGLD optimizer.

  If momentum_decay is set to zero, we get the SGLD method [1]. Otherwise,
  we get the underdamped SGLD (SGHMC) method [2].

  Args:
    step_size_fn: a function taking the training step as input and producing
      the step size as output.
    seed: int, random seed.
    momentum_decay: float, momentum decay parameter (default: 0).
    preconditioner: Preconditioner, an object representing the preconditioner,
      or None; if None, the identity preconditioner is used (default: None).

  References:
    [1] "Bayesian Learning via Stochastic Gradient Langevin Dynamics",
        Max Welling, Yee Whye Teh; ICML 2011.
    [2] "Stochastic Gradient Hamiltonian Monte Carlo",
        Tianqi Chen, Emily B. Fox, Carlos Guestrin; ICML 2014.
  """

    if preconditioner is None:
        preconditioner = get_identity_preconditioner()

    def init_fn(params):
        return OptaxSGLDState(count=jnp.zeros([], jnp.int32),
                              rng_key=jax.random.PRNGKey(seed),
                              momentum=jax.tree_map(jnp.zeros_like, params),
                              preconditioner_state=preconditioner.init(params))

    def update_fn(gradient, state, params=None):
        del params
        lr = step_size_fn(state.count)
        lr_sqrt = jnp.sqrt(lr)
        noise_std = jnp.sqrt(2 * (1 - momentum_decay))

        preconditioner_state = preconditioner.update_preconditioner(
            gradient, state.preconditioner_state)

        noise, new_key = tree_utils.normal_like_tree(gradient, state.rng_key)
        noise = preconditioner.multiply_by_m_sqrt(noise, preconditioner_state)

        def update_momentum(m, g, n):
            return momentum_decay * m + g * lr_sqrt + n * noise_std

        momentum = jax.tree_map(update_momentum, state.momentum, gradient,
                                noise)
        updates = preconditioner.multiply_by_m_inv(momentum,
                                                   preconditioner_state)
        updates = jax.tree_map(lambda m: m * lr_sqrt, updates)
        return updates, OptaxSGLDState(
            count=state.count + 1,
            rng_key=new_key,
            momentum=momentum,
            preconditioner_state=preconditioner_state)

    return GradientTransformation(init_fn, update_fn)
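Since sgld_gradient_update returns an ordinary GradientTransformation, it drops into the same init/update/apply_updates loop as any other optax optimizer. A hedged usage sketch (OptaxSGLDState, the preconditioner helpers and tree_utils.normal_like_tree come from this function's source repository; the gradient function is left abstract):

# A constant step size; a decaying schedule function could be passed instead.
sgld = sgld_gradient_update(step_size_fn=lambda step: 1e-5, seed=0,
                            momentum_decay=0.9)

# One step, where params and grads are arbitrary pytrees. Note that
# optax.apply_updates *adds* the update, so the sign of grads must match the
# direction you want the parameters to move in.
# opt_state = sgld.init(params)
# updates, opt_state = sgld.update(grads, opt_state)
# params = optax.apply_updates(params, updates)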
Example #5
    def __init__(self,
                 direct_rl_learner_factory: Callable[
                     [Any, Iterator[reverb.ReplaySample]], acme.Learner],
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 rnd_network: rnd_networks.RNDNetworks,
                 rng_key: jnp.ndarray,
                 grad_updates_per_batch: int,
                 is_sequence_based: bool,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        self._is_sequence_based = is_sequence_based

        target_key, predictor_key = jax.random.split(rng_key)
        target_params = rnd_network.target.init(target_key)
        predictor_params = rnd_network.predictor.init(predictor_key)
        optimizer_state = optimizer.init(predictor_params)

        self._state = RNDTrainingState(optimizer_state=optimizer_state,
                                       params=predictor_params,
                                       target_params=target_params,
                                       steps=0)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        loss = functools.partial(rnd_loss, networks=rnd_network)
        self._update = functools.partial(rnd_update_step,
                                         loss_fn=loss,
                                         optimizer=optimizer)
        self._update = utils.process_multiple_batches(self._update,
                                                      grad_updates_per_batch)
        self._update = jax.jit(self._update)

        self._get_reward = jax.jit(
            functools.partial(rnd_networks.compute_rnd_reward,
                              networks=rnd_network))

        # Generator expression that works the same as an iterator.
        # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
        updated_iterator = (self._process_sample(sample)
                            for sample in iterator)

        self._direct_rl_learner = direct_rl_learner_factory(
            rnd_network.direct_rl_networks, updated_iterator)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example #6
  def __init__(self,
               network: hk.Transformed,
               obs_spec: specs.Array,
               optimizer: optax.GradientTransformation,
               rng: hk.PRNGSequence,
               dataset: tf.data.Dataset,
               loss_fn: LossFn = _sparse_categorical_cross_entropy,
               counter: counting.Counter = None,
               logger: loggers.Logger = None):
    """Initializes the learner."""

    def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
      # Pull out the data needed for updates.
      o_tm1, a_tm1, r_t, d_t, o_t = sample.data
      del r_t, d_t, o_t
      logits = network.apply(params, o_tm1)
      return jnp.mean(loss_fn(a_tm1, logits))

    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
      """Do a step of SGD."""
      grad_fn = jax.value_and_grad(loss)
      loss_value, gradients = grad_fn(state.params, sample)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      new_state = TrainingState(
          params=new_params, opt_state=new_opt_state, steps=steps)

      # Compute the global norm of the gradients for logging.
      global_gradient_norm = optax.global_norm(gradients)
      fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

      return new_state, fetches

    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Get an iterator over the dataset.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
    # TODO(b/155086959): Fix type stubs and remove.

    # Initialise parameters and optimiser state.
    initial_params = network.init(
        next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
    initial_opt_state = optimizer.init(initial_params)

    self._state = TrainingState(
        params=initial_params, opt_state=initial_opt_state, steps=0)

    self._sgd_step = jax.jit(sgd_step)
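The default loss_fn above, _sparse_categorical_cross_entropy, is not shown in this snippet. A plausible stand-in, assuming it takes integer action labels and unnormalized logits and returns per-example negative log-likelihoods (this definition is an assumption, not the source's):

import jax
import jax.numpy as jnp

def sparse_categorical_cross_entropy(labels: jnp.ndarray,
                                     logits: jnp.ndarray) -> jnp.ndarray:
    # labels: [batch] integer class ids; logits: [batch, num_classes].
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]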
Example #7
def ail_update_step(
        state: DiscriminatorTrainingState, data: Tuple[types.Transition,
                                                       types.Transition],
        optimizer: optax.GradientTransformation,
        ail_network: ail_networks.AILNetworks, loss_fn: losses.Loss
) -> Tuple[DiscriminatorTrainingState, losses.Metrics]:
    """Run an update steps on the given transitions.

  Args:
    state: The learner state.
    data: Demo and rb transitions.
    optimizer: Discriminator optimizer.
    ail_network: AIL networks.
    loss_fn: Discriminator loss to minimize.

  Returns:
    A new state and metrics.
  """
    demo_transitions, rb_transitions = data
    key, discriminator_key, loss_key = jax.random.split(state.key, 3)

    def compute_loss(
            discriminator_params: networks_lib.Params) -> losses.LossOutput:
        discriminator_fn = functools.partial(
            ail_network.discriminator_network.apply,
            discriminator_params,
            state.policy_params,
            is_training=True,
            rng=discriminator_key)
        return loss_fn(discriminator_fn, state.discriminator_state,
                       demo_transitions, rb_transitions, loss_key)

    loss_grad = jax.grad(compute_loss, has_aux=True)

    grads, (loss,
            new_discriminator_state) = loss_grad(state.discriminator_params)

    update, optimizer_state = optimizer.update(
        grads, state.optimizer_state, params=state.discriminator_params)
    discriminator_params = optax.apply_updates(state.discriminator_params,
                                               update)

    new_state = DiscriminatorTrainingState(
        optimizer_state=optimizer_state,
        discriminator_params=discriminator_params,
        discriminator_state=new_discriminator_state,
        policy_params=state.policy_params,  # Not modified.
        key=key,
        steps=state.steps + 1,
    )
    return new_state, loss
Example #8
def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation:
    """Adds each parameter, scaled by `weight_decay`, to its update, for all
    parameters with more than one dimension (i.e. excludes layer norms and biases).

    Args:
      weight_decay: a scalar weight decay rate.

    Returns:
      An (init_fn, update_fn) tuple.
    """

    def init_fn(_):
        return AdditiveWeightDecayState()

    def update_fn(updates, state, params):
        updates = jax.tree_multimap(
            lambda g, p: g + weight_decay * p * (len(g.shape) > 1),
            updates, params)
        return updates, state

    return GradientTransformation(init_fn, update_fn)
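As with the clipping transform earlier, this weight-decay transform is meant to be chained in front of the learning-rate scaling so the decay term ends up inside the final update; transformations whose update_fn reads the parameters must also be given them at update time. A sketch (optax.add_decayed_weights is the built-in analogue, though it decays every parameter unless given a mask):

import optax

optimizer = optax.chain(
    additive_weight_decay(1e-4),
    optax.sgd(learning_rate=1e-3),
)
# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)  # params required
# params = optax.apply_updates(params, updates)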
Example #9
def fit(params: optax.Params,
        opt: optax.GradientTransformation) -> optax.Params:
    opt_state = opt.init(params)

    @jax.jit
    def step(params, opt_state, batch, labels):
        (loss_val, accuracy), grads = jax.value_and_grad(loss,
                                                         has_aux=True)(params,
                                                                       batch,
                                                                       labels)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        params, opt_state, loss_val, accuracy = step(params, opt_state, batch,
                                                     labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return params
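fit assumes several names defined elsewhere in its module: a loss(params, batch, labels) function returning a (loss, accuracy) pair (hence has_aux=True), plus train_data, train_labels and nb_steps. A minimal sketch of a compatible loss, assuming a hypothetical apply_fn that maps (params, batch) to [batch, num_classes] logits:

import jax
import jax.numpy as jnp
import optax

def loss(params, batch, labels):
    logits = apply_fn(params, batch)  # apply_fn is a hypothetical forward pass
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    loss_val = jnp.mean(optax.softmax_cross_entropy(logits, one_hot))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss_val, accuracy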
Example #10
    def __init__(self,
                 counter: counting.Counter,
                 direct_rl_learner_factory: Callable[
                     [Iterator[reverb.ReplaySample]], acme.Learner],
                 loss_fn: losses.Loss,
                 iterator: Iterator[AILSample],
                 discriminator_optimizer: optax.GradientTransformation,
                 ail_network: ail_networks.AILNetworks,
                 discriminator_key: networks_lib.PRNGKey,
                 is_sequence_based: bool,
                 num_sgd_steps_per_step: int = 1,
                 policy_variable_name: Optional[str] = None,
                 logger: Optional[loggers.Logger] = None):
        """AIL Learner.

    Args:
      counter: Counter.
      direct_rl_learner_factory: Function that creates the direct RL learner
        when passed a replay sample iterator.
      loss_fn: Discriminator loss.
      iterator: Iterator that returns AILSamples.
      discriminator_optimizer: Discriminator optax optimizer.
      ail_network: AIL networks.
      discriminator_key: RNG key.
      is_sequence_based: If True, the direct RL algorithm uses the SequenceAdder
        data format. Otherwise the learner assumes that the direct RL algorithm
        uses the NStepTransitionAdder.
      num_sgd_steps_per_step: Number of discriminator gradient updates per step.
      policy_variable_name: The name of the policy variable to retrieve
        direct_rl policy parameters.
      logger: Logger.
    """
        self._is_sequence_based = is_sequence_based

        state_key, networks_key = jax.random.split(discriminator_key)

        # Generator expression that works the same as an iterator.
        # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
        iterator, direct_rl_iterator = itertools.tee(iterator)
        direct_rl_iterator = (self._process_sample(sample.direct_sample)
                              for sample in direct_rl_iterator)
        self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator)

        self._iterator = iterator

        if policy_variable_name is not None:

            def get_policy_params():
                return self._direct_rl_learner.get_variables(
                    [policy_variable_name])[0]

            self._get_policy_params = get_policy_params

        else:
            self._get_policy_params = lambda: None

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

        # Use the JIT compiler.
        self._update_step = functools.partial(
            ail_update_step,
            optimizer=discriminator_optimizer,
            ail_network=ail_network,
            loss_fn=loss_fn)
        self._update_step = utils.process_multiple_batches(
            self._update_step, num_sgd_steps_per_step)
        self._update_step = jax.jit(self._update_step)

        discriminator_params, discriminator_state = (
            ail_network.discriminator_network.init(networks_key))
        self._state = DiscriminatorTrainingState(
            optimizer_state=discriminator_optimizer.init(discriminator_params),
            discriminator_params=discriminator_params,
            discriminator_state=discriminator_state,
            policy_params=self._get_policy_params(),
            key=state_key,
            steps=0,
        )

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None

        self._get_reward = jax.jit(
            functools.partial(ail_networks.compute_ail_reward,
                              networks=ail_network))
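The iterator handling above relies on itertools.tee to split one sample stream into two independent ones, with a generator expression post-processing the copy that feeds the direct RL learner. A tiny illustration of that pattern with plain Python values (unrelated to the AIL sample types):

import itertools

source = iter(range(5))
left, right = itertools.tee(source)
processed = (x * 10 for x in right)  # lazily transformed, like direct_rl_iterator

print(list(left))       # [0, 1, 2, 3, 4]
print(list(processed))  # [0, 10, 20, 30, 40]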
Example #11
    def __init__(
        self,
        preprocessor: processors.Processor,
        sample_network_input: jnp.ndarray,
        network: parts.Network,
        optimizer: optax.GradientTransformation,
        transition_accumulator: Any,
        replay: replay_lib.TransitionReplay,
        shaping_function,
        mask_probability: float,
        num_heads: int,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        grad_error_bound: float,
        rng_key: parts.PRNGKey,
    ):
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._mask_probabilities = jnp.array(
            [mask_probability, 1 - mask_probability])
        self._num_heads = num_heads
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key, shaping_key = jax.random.split(
                rng_key, 4)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).multi_head_output
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).multi_head_output

            # batch by num_heads -> batch by num_heads by num_actions
            mask = jnp.einsum('ij,k->ijk', transitions.mask_t,
                              jnp.ones(q_tm1.shape[-1]))

            masked_q = jnp.multiply(mask, q_tm1)
            masked_q_target = jnp.multiply(mask, q_target_t)

            flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
            flattened_q_target = jnp.reshape(q_target_t,
                                             (-1, q_target_t.shape[-1]))

            # compute shaping function F(s, a, s')
            shaped_rewards = shaping_function(q_target_t, transitions,
                                              shaping_key)

            repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
            repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
            repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

            td_errors = _batch_q_learning(
                flattened_q,
                repeated_actions,
                repeated_rewards,
                repeated_discounts,
                flattened_q_target,
            )

            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size * num_heads, )
            loss = jnp.mean(losses)
            return loss

        def update(rng_key, opt_state, online_params, target_params,
                   transitions):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                                transitions, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)

            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).random_head_q_value[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            return rng_key, a_t

        self._select_action = jax.jit(select_action)
Example #12
    def __init__(self,
                 networks: td3_networks.TD3Networks,
                 random_key: networks_lib.PRNGKey,
                 discount: float,
                 iterator: Iterator[reverb.ReplaySample],
                 policy_optimizer: optax.GradientTransformation,
                 critic_optimizer: optax.GradientTransformation,
                 twin_critic_optimizer: optax.GradientTransformation,
                 delay: int = 2,
                 target_sigma: float = 0.2,
                 noise_clip: float = 0.5,
                 tau: float = 0.005,
                 use_sarsa_target: bool = False,
                 bc_alpha: Optional[float] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initializes the TD3 learner.

    Args:
      networks: TD3 networks.
      random_key: a key for random number generation.
      discount: discount to use for TD updates
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      twin_critic_optimizer: the twin Q-function optimizer.
      delay: ratio of critic updates to policy updates (see TD3);
        delay=2 means 2 updates of the critic for 1 policy update.
      target_sigma: std of zero mean Gaussian added to the action of
        the next_state, for critic evaluation (reducing overestimation bias).
      noise_clip: hard constraint on target noise.
      tau: target parameters smoothing coefficient.
      use_sarsa_target: compute on-policy target using iterator's actions rather
        than sampled actions.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused.
        This only works when the learner is used as an offline algorithm, i.e.
        TD3Builder does not support adding the SARSA target to the replay
        buffer.
      bc_alpha: if not None, implements TD3+BC.
        See comments in TD3Config.bc_alpha for details.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'.
    """
        def policy_loss(
            policy_params: networks_lib.Params,
            critic_params: networks_lib.Params,
            transition: types.NestedArray,
        ) -> jnp.ndarray:
            # Computes the discrete policy gradient loss.
            action = networks.policy_network.apply(policy_params,
                                                   transition.observation)
            grad_critic = jax.vmap(jax.grad(networks.critic_network.apply,
                                            argnums=2),
                                   in_axes=(None, 0, 0))
            dq_da = grad_critic(critic_params, transition.observation, action)
            batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0))
            loss = jnp.mean(batch_dpg_learning(action, dq_da))
            if bc_alpha is not None:
                # BC regularization for offline RL
                q_sa = networks.critic_network.apply(critic_params,
                                                     transition.observation,
                                                     action)
                bc_factor = jax.lax.stop_gradient(bc_alpha /
                                                  jnp.mean(jnp.abs(q_sa)))
                loss += jnp.mean(
                    jnp.square(action - transition.action)) / bc_factor
            return loss

        def critic_loss(
            critic_params: networks_lib.Params,
            state: TrainingState,
            transition: types.Transition,
            random_key: jnp.ndarray,
        ):
            # Computes the critic loss.
            q_tm1 = networks.critic_network.apply(critic_params,
                                                  transition.observation,
                                                  transition.action)

            if use_sarsa_target:
                # TODO(b/222674779): use N-steps Trajectories to get the next actions.
                assert 'next_action' in transition.extras, (
                    'next actions should be given as extras for one step RL.')
                action = transition.extras['next_action']
            else:
                action = networks.policy_network.apply(
                    state.target_policy_params, transition.next_observation)
                action = networks.add_policy_noise(action, random_key,
                                                   target_sigma, noise_clip)

            q_t = networks.critic_network.apply(state.target_critic_params,
                                                transition.next_observation,
                                                action)
            twin_q_t = networks.twin_critic_network.apply(
                state.target_twin_critic_params, transition.next_observation,
                action)

            q_t = jnp.minimum(q_t, twin_q_t)

            target_q_tm1 = transition.reward + discount * transition.discount * q_t
            td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1

            return jnp.mean(jnp.square(td_error))

        def update_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            random_key, key_critic, key_twin = jax.random.split(
                state.random_key, 3)

            # Updates on the critic: compute the gradients, and update using
            # Polyak averaging.
            critic_loss_and_grad = jax.value_and_grad(critic_loss)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state, transitions, key_critic)
            critic_updates, critic_opt_state = critic_optimizer.update(
                critic_gradients, state.critic_opt_state)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)
            # In the original authors' implementation the critic target update is
            # delayed similarly to the policy update which we found empirically to
            # perform slightly worse.
            target_critic_params = optax.incremental_update(
                new_tensors=critic_params,
                old_tensors=state.target_critic_params,
                step_size=tau)

            # Updates on the twin critic: compute the gradients, and update using
            # Polyak averaging.
            twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad(
                state.twin_critic_params, state, transitions, key_twin)
            twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update(
                twin_critic_gradients, state.twin_critic_opt_state)
            twin_critic_params = optax.apply_updates(state.twin_critic_params,
                                                     twin_critic_updates)
            # In the original authors' implementation the twin critic target update is
            # delayed similarly to the policy update which we found empirically to
            # perform slightly worse.
            target_twin_critic_params = optax.incremental_update(
                new_tensors=twin_critic_params,
                old_tensors=state.target_twin_critic_params,
                step_size=tau)

            # Updates on the policy: compute the gradients, and update using
            # Polyak averaging (if delay enabled, the update might not be applied).
            policy_loss_and_grad = jax.value_and_grad(policy_loss)
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params, transitions)

            def update_policy_step():
                policy_updates, policy_opt_state = policy_optimizer.update(
                    policy_gradients, state.policy_opt_state)
                policy_params = optax.apply_updates(state.policy_params,
                                                    policy_updates)
                target_policy_params = optax.incremental_update(
                    new_tensors=policy_params,
                    old_tensors=state.target_policy_params,
                    step_size=tau)
                return policy_params, target_policy_params, policy_opt_state

            # The update on the policy is applied every `delay` steps.
            current_policy_state = (state.policy_params,
                                    state.target_policy_params,
                                    state.policy_opt_state)
            policy_params, target_policy_params, policy_opt_state = jax.lax.cond(
                state.steps % delay == 0,
                lambda _: update_policy_step(),
                lambda _: current_policy_state,
                operand=None)

            steps = state.steps + 1

            new_state = TrainingState(
                policy_params=policy_params,
                critic_params=critic_params,
                twin_critic_params=twin_critic_params,
                target_policy_params=target_policy_params,
                target_critic_params=target_critic_params,
                target_twin_critic_params=target_twin_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                twin_critic_opt_state=twin_critic_opt_state,
                steps=steps,
                random_key=random_key,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
                'twin_critic_loss': twin_critic_loss_value,
            }

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Create prefetching dataset iterator.
        self._iterator = iterator

        # Faster sgd step
        update_step = utils.process_multiple_batches(update_step,
                                                     num_sgd_steps_per_step)
        # Use the JIT compiler.
        self._update_step = jax.jit(update_step)

        (key_init_policy, key_init_twin, key_init_target,
         key_state) = jax.random.split(random_key, 4)
        # Create the network parameters and copy into the target network parameters.
        initial_policy_params = networks.policy_network.init(key_init_policy)
        initial_critic_params = networks.critic_network.init(key_init_twin)
        initial_twin_critic_params = networks.twin_critic_network.init(
            key_init_target)

        initial_target_policy_params = initial_policy_params
        initial_target_critic_params = initial_critic_params
        initial_target_twin_critic_params = initial_twin_critic_params

        # Initialize optimizers.
        initial_policy_opt_state = policy_optimizer.init(initial_policy_params)
        initial_critic_opt_state = critic_optimizer.init(initial_critic_params)
        initial_twin_critic_opt_state = twin_critic_optimizer.init(
            initial_twin_critic_params)

        # Create initial state.
        self._state = TrainingState(
            policy_params=initial_policy_params,
            target_policy_params=initial_target_policy_params,
            critic_params=initial_critic_params,
            twin_critic_params=initial_twin_critic_params,
            target_critic_params=initial_target_critic_params,
            target_twin_critic_params=initial_target_twin_critic_params,
            policy_opt_state=initial_policy_opt_state,
            critic_opt_state=initial_critic_opt_state,
            twin_critic_opt_state=initial_twin_critic_opt_state,
            steps=0,
            random_key=key_state)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
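The target networks above are maintained with optax.incremental_update, which performs Polyak averaging leaf-wise over parameter pytrees: target <- step_size * new + (1 - step_size) * target. A small sketch with toy parameters:

import jax.numpy as jnp
import optax

online = {'w': jnp.ones(3)}
target = {'w': jnp.zeros(3)}

# After this call, target['w'] == 0.005 * online['w'] + 0.995 * target['w'].
target = optax.incremental_update(new_tensors=online, old_tensors=target,
                                  step_size=0.005)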
Example #13
    def __init__(self,
                 observation_spec: specs.Array,
                 action_spec: specs.DiscreteArray,
                 network: PolicyValueNet,
                 initial_rnn_state: RNNState,
                 optimizer: optax.GradientTransformation,
                 rng: hk.PRNGSequence,
                 buffer_length: int,
                 discount: float,
                 td_lambda: float,
                 entropy_cost: float = 1.,
                 critic_cost: float = 1.):
        @jax.jit
        def pack(trajectory: buffer.Trajectory) -> List[jnp.ndarray]:
            """Converts a trajectory into an input."""
            observations = trajectory.observations[:, None, ...]

            rewards = jnp.concatenate([
                trajectory.previous_reward,
                jnp.squeeze(trajectory.rewards, -1)
            ], -1)
            rewards = jnp.squeeze(rewards)
            rewards = jnp.expand_dims(rewards, (1, 2))

            previous_action = jax.nn.one_hot(trajectory.previous_action,
                                             action_spec.num_values)
            actions = jax.nn.one_hot(jnp.squeeze(trajectory.actions, 1),
                                     action_spec.num_values)
            actions = jnp.expand_dims(
                jnp.concatenate([previous_action, actions], 0), 1)

            return [observations, rewards, actions]

        @jax.jit
        def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
            """"Computes a linear combination of the policy gradient loss and value loss
      and regularizes it with an entropy term."""
            inputs = pack(trajectory)

            # Dynamically unroll the network. This Haiku utility function unpacks the
            # list of input tensors such that the i^{th} row from each input tensor
            # is presented to the i^{th} unrolled RNN module.
            (logits, values, _, _,
             state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
                 network, inputs, rnn_unroll_state)
            trajectory_len = trajectory.actions.shape[0]

            # Compute the combined loss given the output of the model.
            td_errors = rlax.td_lambda(v_tm1=values[:-1, 0],
                                       r_t=jnp.squeeze(trajectory.rewards, -1),
                                       discount_t=trajectory.discounts *
                                       discount,
                                       v_t=values[1:, 0],
                                       lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1, 0],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones(trajectory_len))
            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

            combined_loss = (actor_loss + critic_cost * critic_loss +
                             entropy_cost * entropy_loss)

            return combined_loss, new_rnn_unroll_state

        # Transform the loss into a pure function.
        loss_fn = hk.without_apply_rng(hk.transform(loss,
                                                    apply_rng=True)).apply

        # Define update function.
        @jax.jit
        def sgd_step(state: AgentState,
                     trajectory: buffer.Trajectory) -> AgentState:
            """Performs a step of SGD over a trajectory."""
            gradients, new_rnn_state = jax.grad(loss_fn, has_aux=True)(
                state.params, trajectory, state.rnn_unroll_state)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)
            return state._replace(params=new_params,
                                  opt_state=new_opt_state,
                                  rnn_unroll_state=new_rnn_state)

        # Initialize network parameters.
        init, forward = hk.without_apply_rng(
            hk.transform(network, apply_rng=True))
        dummy_observation = jnp.zeros((1, *observation_spec.shape),
                                      dtype=observation_spec.dtype)
        dummy_reward = jnp.zeros((1, 1, 1))
        dummy_action = jnp.zeros((1, 1, action_spec.num_values))
        inputs = [dummy_observation, dummy_reward, dummy_action]
        initial_params = init(next(rng), inputs, initial_rnn_state)
        initial_opt_state = optimizer.init(initial_params)

        # Internalize state.
        self._state = AgentState(initial_params, initial_opt_state,
                                 initial_rnn_state, initial_rnn_state)
        self._forward = jax.jit(forward)
        self._buffer = buffer.Buffer(observation_spec, action_spec,
                                     buffer_length)
        self._sgd_step = sgd_step
        self._rng = rng
        self._initial_rnn_state = initial_rnn_state
        self._action_spec = action_spec
Example #14
    def __init__(self,
                 networks: CRRNetworks,
                 random_key: networks_lib.PRNGKey,
                 discount: float,
                 target_update_period: int,
                 policy_loss_coeff_fn: PolicyLossCoeff,
                 iterator: Iterator[types.Transition],
                 policy_optimizer: optax.GradientTransformation,
                 critic_optimizer: optax.GradientTransformation,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 grad_updates_per_batch: int = 1,
                 use_sarsa_target: bool = False):
        """Initializes the CRR learner.

    Args:
      networks: CRR networks.
      random_key: a key for random number generation.
      discount: discount to use for TD updates.
      target_update_period: period to update target's parameters.
      policy_loss_coeff_fn: set the loss function for the policy.
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      grad_updates_per_batch: how many gradient updates given a sampled batch.
      use_sarsa_target: compute on-policy target using iterator's actions rather
        than sampled actions.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused.
    """

        critic_network = networks.critic_network
        policy_network = networks.policy_network

        def policy_loss(
            policy_params: networks_lib.Params,
            critic_params: networks_lib.Params,
            transition: types.Transition,
            key: networks_lib.PRNGKey,
        ) -> jnp.ndarray:
            # Compute the loss coefficients.
            coeff = policy_loss_coeff_fn(networks, policy_params,
                                         critic_params, transition, key)
            coeff = jax.lax.stop_gradient(coeff)
            # Return the weighted loss.
            dist_params = policy_network.apply(policy_params,
                                               transition.observation)
            logp_action = networks.log_prob(dist_params, transition.action)
            return -jnp.mean(logp_action * coeff)

        def critic_loss(
            critic_params: networks_lib.Params,
            target_policy_params: networks_lib.Params,
            target_critic_params: networks_lib.Params,
            transition: types.Transition,
            key: networks_lib.PRNGKey,
        ):
            # Sample the next action.
            if use_sarsa_target:
                # TODO(b/222674779): use N-steps Trajectories to get the next actions.
                assert 'next_action' in transition.extras, (
                    'next actions should be given as extras for one step RL.')
                next_action = transition.extras['next_action']
            else:
                next_dist_params = policy_network.apply(
                    target_policy_params, transition.next_observation)
                next_action = networks.sample(next_dist_params, key)
            # Calculate the value of the next state and action.
            next_q = critic_network.apply(target_critic_params,
                                          transition.next_observation,
                                          next_action)
            target_q = transition.reward + transition.discount * discount * next_q
            target_q = jax.lax.stop_gradient(target_q)

            q = critic_network.apply(critic_params, transition.observation,
                                     transition.action)
            q_error = q - target_q
            # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
            # TODO(sertan): Replace with a distributional critic. CRR paper states
            # that this may perform better.
            return 0.5 * jnp.mean(jnp.square(q_error))

        policy_loss_and_grad = jax.value_and_grad(policy_loss)
        critic_loss_and_grad = jax.value_and_grad(critic_loss)

        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_policy, key_critic = jax.random.split(state.key, 3)

            # Compute losses and their gradients.
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params, transitions,
                key_policy)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state.target_policy_params,
                state.target_critic_params, transitions, key_critic)

            # Get optimizer updates and state.
            policy_updates, policy_opt_state = policy_optimizer.update(
                policy_gradients, state.policy_opt_state)
            critic_updates, critic_opt_state = critic_optimizer.update(
                critic_gradients, state.critic_opt_state)

            # Apply optimizer updates to parameters.
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_updates)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)

            steps = state.steps + 1

            # Periodically update target networks.
            target_policy_params, target_critic_params = optax.periodic_update(
                (policy_params, critic_params),
                (state.target_policy_params, state.target_critic_params),
                steps, target_update_period)

            new_state = TrainingState(
                policy_params=policy_params,
                target_policy_params=target_policy_params,
                critic_params=critic_params,
                target_critic_params=target_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                steps=steps,
                key=key,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
            }

            return new_state, metrics

        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  grad_updates_per_batch)
        self._sgd_step = jax.jit(sgd_step)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Create prefetching dataset iterator.
        self._iterator = iterator

        # Create the network parameters and copy into the target network parameters.
        key, key_policy, key_critic = jax.random.split(random_key, 3)
        initial_policy_params = policy_network.init(key_policy)
        initial_critic_params = critic_network.init(key_critic)
        initial_target_policy_params = initial_policy_params
        initial_target_critic_params = initial_critic_params

        # Initialize optimizers.
        initial_policy_opt_state = policy_optimizer.init(initial_policy_params)
        initial_critic_opt_state = critic_optimizer.init(initial_critic_params)

        # Create initial state.
        self._state = TrainingState(
            policy_params=initial_policy_params,
            target_policy_params=initial_target_policy_params,
            critic_params=initial_critic_params,
            target_critic_params=initial_target_critic_params,
            policy_opt_state=initial_policy_opt_state,
            critic_opt_state=initial_critic_opt_state,
            steps=0,
            key=key,
        )

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
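Here the target networks are refreshed with optax.periodic_update, which copies the new tensors into the old ones every update_period steps and otherwise returns the old tensors unchanged. A small sketch with toy parameters:

import jax.numpy as jnp
import optax

online = {'w': jnp.ones(2)}
target = {'w': jnp.zeros(2)}
steps = jnp.asarray(4, dtype=jnp.int32)

# steps % update_period == 0, so the online params are copied into the target.
target = optax.periodic_update(online, target, steps, update_period=2)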
Example #15
    def __init__(self,
                 batch_size: int,
                 networks: CQLNetworks,
                 random_key: networks_lib.PRNGKey,
                 demonstrations: Iterator[types.Transition],
                 policy_optimizer: optax.GradientTransformation,
                 critic_optimizer: optax.GradientTransformation,
                 tau: float = 0.005,
                 fixed_cql_coefficient: Optional[float] = None,
                 cql_lagrange_threshold: Optional[float] = None,
                 cql_num_samples: int = 10,
                 num_sgd_steps_per_step: int = 1,
                 reward_scale: float = 1.0,
                 discount: float = 0.99,
                 fixed_entropy_coefficient: Optional[float] = None,
                 target_entropy: Optional[float] = 0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the CQL learner.

    Args:
      batch_size: batch size.
      networks: CQL networks.
      random_key: a key for random number generation.
      demonstrations: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      tau: target smoothing coefficient.
      fixed_cql_coefficient: the value for cql coefficient. If None, an adaptive
        coefficient will be used.
      cql_lagrange_threshold: a threshold that controls the adaptive loss for
        the cql coefficient.
      cql_num_samples: number of samples used to compute logsumexp(Q) via
        importance sampling.
      num_sgd_steps_per_step: how many gradient updates to perform per batch.
        The batch is split into this many smaller batches, so batch_size should
        be a multiple of num_sgd_steps_per_step.
      reward_scale: reward scale.
      discount: discount to use for TD updates.
      fixed_entropy_coefficient: coefficient applied to the entropy bonus. If
        None, an adaptive coefficient will be used.
      target_entropy: Target entropy when using an adaptive entropy bonus.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
    """
        adaptive_entropy_coefficient = fixed_entropy_coefficient is None
        action_spec = networks.environment_specs.actions
        if adaptive_entropy_coefficient:
            # sac_alpha is the temperature parameter that determines the relative
            # importance of the entropy term versus the reward.
            log_sac_alpha = jnp.asarray(0., dtype=jnp.float32)
            alpha_optimizer = optax.adam(learning_rate=3e-4)
            alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha)
        else:
            if target_entropy:
                raise ValueError('target_entropy should not be set when '
                                 'fixed_entropy_coefficient is provided')

        adaptive_cql_coefficient = fixed_cql_coefficient is None
        if adaptive_cql_coefficient:
            log_cql_alpha = jnp.asarray(0., dtype=jnp.float32)
            cql_optimizer = optax.adam(learning_rate=3e-4)
            cql_optimizer_state = cql_optimizer.init(log_cql_alpha)
        else:
            if cql_lagrange_threshold:
                raise ValueError(
                    'cql_lagrange_threshold should not be set when '
                    'fixed_cql_coefficient is provided')

        def alpha_loss(log_sac_alpha: jnp.ndarray,
                       policy_params: networks_lib.Params,
                       transitions: types.Transition,
                       key: jnp.ndarray) -> jnp.ndarray:
            """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            sac_alpha = jnp.exp(log_sac_alpha)
            sac_alpha_loss = sac_alpha * jax.lax.stop_gradient(-log_prob -
                                                               target_entropy)
            return jnp.mean(sac_alpha_loss)

        def sac_critic_loss(q_old_action: jnp.ndarray,
                            policy_params: networks_lib.Params,
                            target_critic_params: networks_lib.Params,
                            sac_alpha: jnp.ndarray,
                            transitions: types.Transition,
                            key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Computes the SAC part of the loss."""
            next_dist_params = networks.policy_network.apply(
                policy_params, transitions.next_observation)
            next_action = networks.sample(next_dist_params, key)
            next_log_prob = networks.log_prob(next_dist_params, next_action)
            next_q = networks.critic_network.apply(
                target_critic_params, transitions.next_observation,
                next_action)
            next_v = jnp.min(next_q, axis=-1) - sac_alpha * next_log_prob
            target_q = jax.lax.stop_gradient(
                transitions.reward * reward_scale +
                transitions.discount * discount * next_v)
            return jnp.mean(
                jnp.square(q_old_action - jnp.expand_dims(target_q, -1)))

        def batched_critic(actions: jnp.ndarray,
                           critic_params: networks_lib.Params,
                           observation: jnp.ndarray) -> jnp.ndarray:
            """Applies the critic network to a batch of sampled actions."""
            actions = jax.lax.stop_gradient(actions)
            tiled_actions = jnp.reshape(actions,
                                        (batch_size * cql_num_samples, -1))
            tiled_states = jnp.tile(observation, [cql_num_samples, 1])
            tiled_q = networks.critic_network.apply(critic_params,
                                                    tiled_states,
                                                    tiled_actions)
            return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1))

        def cql_critic_loss(q_old_action: jnp.ndarray,
                            critic_params: networks_lib.Params,
                            policy_params: networks_lib.Params,
                            transitions: types.Transition,
                            key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Computes the CQL part of the loss."""
            # The CQL part of the loss is
            #     logsumexp(Q(s,·)) - Q(s,a),
            # where s is the current state, and a the action in the dataset (so
            # Q(s,a) is simply q_old_action).
            # We need to estimate logsumexp(Q). This is done with importance sampling
            # (IS). This function implements the unlabeled equation page 29, Appx. F,
            # in https://arxiv.org/abs/2006.04779.
            # Here, IS is done with the uniform distribution and the policy in the
            # current state s. In their implementation, the authors also add the
            # policy in the transiting state s':
            # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py,
            # (l. 233-236).
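            # Concretely, with N = cql_num_samples actions drawn from each of the three
            # proposals (uniform, pi(.|s), pi(.|s')), the estimate below is
            #     logsumexp_a Q(s, a) ~= log(1/(3N) * sum_i exp(Q(s, a_i) - log q(a_i))),
            # which is what jax.nn.logsumexp(..., b=1/(3 * cql_num_samples)) computes.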

            key_policy, key_policy_next, key_uniform = jax.random.split(key, 3)

            def sampled_q(obs, key):
                actions, log_probs = apply_and_sample_n(
                    key, networks, policy_params, obs, cql_num_samples)
                return batched_critic(
                    actions, critic_params,
                    transitions.observation) - jax.lax.stop_gradient(
                        jnp.expand_dims(log_probs, -1))

            # Sample wrt policy in s
            sampled_q_from_policy = sampled_q(transitions.observation,
                                              key_policy)

            # Sample wrt policy in s'
            sampled_q_from_policy_next = sampled_q(
                transitions.next_observation, key_policy_next)

            # Sample wrt uniform
            actions_uniform = jax.random.uniform(
                key_uniform, (cql_num_samples, batch_size) + action_spec.shape,
                minval=action_spec.minimum,
                maxval=action_spec.maximum)
            log_prob_uniform = -jnp.sum(
                jnp.log(action_spec.maximum - action_spec.minimum))
            sampled_q_from_uniform = (batched_critic(
                actions_uniform, critic_params, transitions.observation) -
                                      log_prob_uniform)

            # Combine the samplings
            combined = jnp.concatenate(
                (sampled_q_from_uniform, sampled_q_from_policy,
                 sampled_q_from_policy_next),
                axis=0)
            lse_q = jax.nn.logsumexp(combined,
                                     axis=0,
                                     b=1. / (3 * cql_num_samples))

            return jnp.mean(lse_q - q_old_action)

        def critic_loss(critic_params: networks_lib.Params,
                        policy_params: networks_lib.Params,
                        target_critic_params: networks_lib.Params,
                        sac_alpha: jnp.ndarray, cql_alpha: jnp.ndarray,
                        transitions: types.Transition,
                        key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Computes the full critic loss."""
            key_cql, key_sac = jax.random.split(key, 2)
            q_old_action = networks.critic_network.apply(
                critic_params, transitions.observation, transitions.action)
            cql_loss = cql_critic_loss(q_old_action, critic_params,
                                       policy_params, transitions, key_cql)
            sac_loss = sac_critic_loss(q_old_action, policy_params,
                                       target_critic_params, sac_alpha,
                                       transitions, key_sac)
            return cql_alpha * cql_loss + sac_loss

        def cql_lagrange_loss(log_cql_alpha: jnp.ndarray,
                              critic_params: networks_lib.Params,
                              policy_params: networks_lib.Params,
                              transitions: types.Transition,
                              key: jnp.ndarray) -> jnp.ndarray:
            """Computes the loss that optimizes the cql coefficient."""
            cql_alpha = jnp.exp(log_cql_alpha)
            q_old_action = networks.critic_network.apply(
                critic_params, transitions.observation, transitions.action)
            return -cql_alpha * (cql_critic_loss(
                q_old_action, critic_params, policy_params, transitions, key) -
                                 cql_lagrange_threshold)

        def actor_loss(policy_params: networks_lib.Params,
                       critic_params: networks_lib.Params,
                       sac_alpha: jnp.ndarray, transitions: types.Transition,
                       key: jnp.ndarray) -> jnp.ndarray:
            """Computes the loss for the policy."""
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            q_action = networks.critic_network.apply(critic_params,
                                                     transitions.observation,
                                                     action)
            min_q = jnp.min(q_action, axis=-1)
            return jnp.mean(sac_alpha * log_prob - min_q)

        alpha_grad = jax.value_and_grad(alpha_loss)
        cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss)
        critic_grad = jax.value_and_grad(critic_loss)
        actor_grad = jax.value_and_grad(actor_loss)

        def update_step(
            state: TrainingState,
            rb_transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)

            if adaptive_entropy_coefficient:
                alpha_loss, alpha_grads = alpha_grad(state.log_sac_alpha,
                                                     state.policy_params,
                                                     rb_transitions, key_alpha)
                sac_alpha = jnp.exp(state.log_sac_alpha)
            else:
                sac_alpha = fixed_entropy_coefficient

            if adaptive_cql_coefficient:
                cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad(
                    state.log_cql_alpha, state.critic_params,
                    state.policy_params, rb_transitions, key_critic)
                cql_lagrange_grads = jnp.clip(cql_lagrange_grads,
                                              -_CQL_GRAD_CLIPPING_VALUE,
                                              _CQL_GRAD_CLIPPING_VALUE)
                cql_alpha = jnp.exp(state.log_cql_alpha)
                cql_alpha = jnp.clip(cql_alpha,
                                     a_min=0.,
                                     a_max=_CQL_COEFFICIENT_MAX_VALUE)
            else:
                cql_alpha = fixed_cql_coefficient

            critic_loss, critic_grads = critic_grad(state.critic_params,
                                                    state.policy_params,
                                                    state.target_critic_params,
                                                    sac_alpha, cql_alpha,
                                                    rb_transitions, key_critic)
            actor_loss, actor_grads = actor_grad(state.policy_params,
                                                 state.critic_params,
                                                 sac_alpha, rb_transitions,
                                                 key_actor)

            # Apply policy gradients
            actor_update, policy_optimizer_state = policy_optimizer.update(
                actor_grads, state.policy_optimizer_state)
            policy_params = optax.apply_updates(state.policy_params,
                                                actor_update)

            # Apply critic gradients
            critic_update, critic_optimizer_state = critic_optimizer.update(
                critic_grads, state.critic_optimizer_state)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_update)

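            # Polyak-average the target critic towards the online critic:
            #     target <- (1 - tau) * target + tau * online.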
            new_target_critic_params = jax.tree_map(
                lambda x, y: x * (1 - tau) + y * tau,
                state.target_critic_params, critic_params)

            metrics = {
                'critic_loss': critic_loss,
                'actor_loss': actor_loss,
            }

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                critic_optimizer_state=critic_optimizer_state,
                policy_params=policy_params,
                critic_params=critic_params,
                target_critic_params=new_target_critic_params,
                key=key,
                steps=state.steps + 1,
            )
            if adaptive_entropy_coefficient:
                # Apply sac_alpha gradients
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                log_sac_alpha = optax.apply_updates(state.log_sac_alpha,
                                                    alpha_update)
                metrics.update({
                    'alpha_loss': alpha_loss,
                    'sac_alpha': jnp.exp(log_sac_alpha),
                })
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    log_sac_alpha=log_sac_alpha)

            if adaptive_cql_coefficient:
                # Apply cql coeff gradients
                cql_update, cql_optimizer_state = cql_optimizer.update(
                    cql_lagrange_grads, state.cql_optimizer_state)
                log_cql_alpha = optax.apply_updates(state.log_cql_alpha,
                                                    cql_update)
                metrics.update({
                    'cql_lagrange_loss': cql_lagrange_loss,
                    'cql_alpha': jnp.exp(log_cql_alpha),
                })
                new_state = new_state._replace(
                    cql_optimizer_state=cql_optimizer_state,
                    log_cql_alpha=log_cql_alpha)

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

        # Iterator on demonstration transitions.
        self._demonstrations = demonstrations

        # Use the JIT compiler.
        self._update_step = utils.process_multiple_batches(
            update_step, num_sgd_steps_per_step)
        self._update_step = jax.jit(self._update_step)

        # Create initial state.
        key_policy, key_q, training_state_key = jax.random.split(random_key, 3)
        del random_key
        policy_params = networks.policy_network.init(key_policy)
        policy_optimizer_state = policy_optimizer.init(policy_params)
        critic_params = networks.critic_network.init(key_q)
        critic_optimizer_state = critic_optimizer.init(critic_params)

        self._state = TrainingState(
            policy_optimizer_state=policy_optimizer_state,
            critic_optimizer_state=critic_optimizer_state,
            policy_params=policy_params,
            critic_params=critic_params,
            target_critic_params=critic_params,
            key=training_state_key,
            steps=0)

        if adaptive_entropy_coefficient:
            self._state = self._state._replace(
                alpha_optimizer_state=alpha_optimizer_state,
                log_sac_alpha=log_sac_alpha)
        if adaptive_cql_coefficient:
            self._state = self._state._replace(
                cql_optimizer_state=cql_optimizer_state,
                log_cql_alpha=log_cql_alpha)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
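
All of the learners on this page share the same optax update pattern: compute the
gradients of a scalar loss, turn them into updates with an
optax.GradientTransformation, apply them with optax.apply_updates, and, where a
target network exists, Polyak-average it towards the online parameters. The
snippet below is a minimal, self-contained sketch of that pattern on a toy
quadratic loss; the names (toy_loss, tau, ...) are illustrative and not taken
from the example above.

import jax
import jax.numpy as jnp
import optax

def toy_loss(params: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for a critic/actor loss.
    return jnp.sum(jnp.square(params - 3.0))

optimizer = optax.adam(1e-2)
params = jnp.zeros(4)
target_params = params
opt_state = optimizer.init(params)
tau = 0.005  # Polyak averaging coefficient.

@jax.jit
def update_step(params, target_params, opt_state):
    loss, grads = jax.value_and_grad(toy_loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # target <- (1 - tau) * target + tau * online.
    target_params = jax.tree_util.tree_map(
        lambda t, o: (1 - tau) * t + tau * o, target_params, params)
    return params, target_params, opt_state, loss

for _ in range(3):
    params, target_params, opt_state, loss = update_step(
        params, target_params, opt_state)
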
Example #16
  def __init__(self,
               network: hk.Transformed,
               obs_spec: specs.Array,
               discount: float,
               importance_sampling_exponent: float,
               target_update_period: int,
               iterator: Iterator[reverb.ReplaySample],
               optimizer: optax.GradientTransformation,
               rng: hk.PRNGSequence,
               max_abs_reward: float = 1.,
               huber_loss_parameter: float = 1.,
               replay_client: reverb.Client = None,
               counter: counting.Counter = None,
               logger: loggers.Logger = None):
    """Initializes the learner."""

    def loss(params: hk.Params, target_params: hk.Params,
             sample: reverb.ReplaySample):
      o_tm1, a_tm1, r_t, d_t, o_t = sample.data
      keys, probs = sample.info[:2]

      # Forward pass.
      q_tm1 = network.apply(params, o_tm1)
      q_t_value = network.apply(target_params, o_t)
      q_t_selector = network.apply(params, o_t)

      # Cast and clip rewards.
      d_t = (d_t * discount).astype(jnp.float32)
      r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32)

      # Compute double Q-learning n-step TD-error.
      batch_error = jax.vmap(rlax.double_q_learning)
      td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
      batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

      # Importance weighting.
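      # Prioritized-replay correction: w_i is proportional to (1 / p_i)^beta,
      # normalized by the maximum weight so the largest weight is 1 (the usual
      # 1/N factor cancels under this normalization).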
      importance_weights = (1. / probs).astype(jnp.float32)
      importance_weights **= importance_sampling_exponent
      importance_weights /= jnp.max(importance_weights)

      # Reweight.
      mean_loss = jnp.mean(importance_weights * batch_loss)  # []

      priorities = jnp.abs(td_error).astype(jnp.float64)

      return mean_loss, (keys, priorities)

    def sgd_step(
        state: TrainingState,
        samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
      grad_fn = jax.grad(loss, has_aux=True)
      gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                              samples)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_params = rlax.periodic_update(
          new_params, state.target_params, steps, self._target_update_period)

      new_state = TrainingState(
          params=new_params,
          target_params=target_params,
          opt_state=new_opt_state,
          steps=steps)

      outputs = LearnerOutputs(keys=keys, priorities=priorities)

      return new_state, outputs

    def update_priorities(outputs: LearnerOutputs):
      replay_client.mutate_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          updates=dict(zip(outputs.keys, outputs.priorities)))

    # Internalise agent components (replay buffer, networks, optimizer).
    self._replay_client = replay_client
    self._iterator = utils.prefetch(iterator)

    # Internalise the hyperparameters.
    self._target_update_period = target_update_period

    # Internalise logging/counting objects.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Initialise parameters and optimiser state.
    initial_params = network.init(
        next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
    initial_target_params = network.init(
        next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
    initial_opt_state = optimizer.init(initial_params)

    self._state = TrainingState(
        params=initial_params,
        target_params=initial_target_params,
        opt_state=initial_opt_state,
        steps=0)

    self._forward = jax.jit(network.apply)
    self._sgd_step = jax.jit(sgd_step)
    self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
Example #17
def train(network_def: nn.Module,
          optim: optax.GradientTransformation,
          alpha_optim: optax.GradientTransformation,
          optimizer_state: jnp.ndarray,
          alpha_optimizer_state: jnp.ndarray,
          network_params: flax.core.FrozenDict,
          target_params: flax.core.FrozenDict,
          log_alpha: jnp.ndarray,
          key: jnp.ndarray,
          states: jnp.ndarray,
          actions: jnp.ndarray,
          next_states: jnp.ndarray,
          rewards: jnp.ndarray,
          terminals: jnp.ndarray,
          cumulative_gamma: float,
          target_entropy: float,
          reward_scale_factor: float) -> Mapping[str, Any]:
  """Run the training step.

  Returns a list of updated values and losses.

  Args:
    network_def: The SAC network definition.
    optim: The optimizer for the SAC network parameters.
    alpha_optim: The optimizer for alpha.
    optimizer_state: The SAC optimizer state.
    alpha_optimizer_state: The alpha optimizer state.
    network_params: Parameters for SAC's online network.
    target_params: The parameters for SAC's target network.
    log_alpha: Parameters for alpha network.
    key: An rng key to use for random action selection.
    states: A batch of states.
    actions: A batch of actions.
    next_states: A batch of next states.
    rewards: A batch of rewards.
    terminals: A batch of terminals.
    cumulative_gamma: The discount factor to use.
    target_entropy: The target entropy for the agent.
    reward_scale_factor: A factor by which to scale rewards.

  Returns:
    A mapping from string keys to values, including updated optimizers and
      training statistics.
  """
  # Alias the online parameters for use in loss_fn without applying gradients
  # through the critic (see the frozen_params usage below).
  frozen_params = network_params

  batch_size = states.shape[0]
  actions = jnp.reshape(actions, (batch_size, -1))  # Flatten

  def loss_fn(
      params: flax.core.FrozenDict, log_alpha: flax.core.FrozenDict,
      state: jnp.ndarray, action: jnp.ndarray, reward: jnp.ndarray,
      next_state: jnp.ndarray, terminal: jnp.ndarray,
      rng: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Calculates the loss for one transition.

    Args:
      params: Parameters for the SAC network.
      log_alpha: SAC's log_alpha parameter.
      state: A single state vector.
      action: A single action vector.
      reward: A reward scalar.
      next_state: A next state vector.
      terminal: A terminal scalar.
      rng: An RNG key to use for sampling actions.

    Returns:
      A tuple containing 1) the combined SAC loss and 2) a mapping containing
        statistics from the loss step.
    """
    rng1, rng2 = jax.random.split(rng, 2)

    # J_Q(\theta) from equation (5) in paper.
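    # Soft Q target (computed a few lines below):
    #     y = reward_scale_factor * r
    #         + cumulative_gamma * (min(Q1', Q2') - alpha * log pi(a'|s')) * (1 - terminal).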
    q_value_1, q_value_2 = network_def.apply(
        params, state, action, method=network_def.critic)
    q_value_1 = jnp.squeeze(q_value_1)
    q_value_2 = jnp.squeeze(q_value_2)

    target_outputs = network_def.apply(target_params, next_state, rng1, True)
    target_q_value_1, target_q_value_2 = target_outputs.critic
    target_q_value = jnp.squeeze(
        jnp.minimum(target_q_value_1, target_q_value_2))

    alpha_value = jnp.exp(log_alpha)
    log_prob = target_outputs.actor.log_probability
    target = reward_scale_factor * reward + cumulative_gamma * (
        target_q_value - alpha_value * log_prob) * (1. - terminal)
    target = jax.lax.stop_gradient(target)
    critic_loss_1 = losses.mse_loss(q_value_1, target)
    critic_loss_2 = losses.mse_loss(q_value_2, target)
    critic_loss = jnp.mean(critic_loss_1 + critic_loss_2)

    # J_{\pi}(\phi) from equation (9) in paper.
    mean_action, sampled_action, action_log_prob = network_def.apply(
        params, state, rng2, method=network_def.actor)

    # We use frozen_params so that gradients can flow back to the actor without
    # being used to update the critic.
    q_value_no_grad_1, q_value_no_grad_2 = network_def.apply(
        frozen_params, state, sampled_action, method=network_def.critic)
    no_grad_q_value = jnp.squeeze(
        jnp.minimum(q_value_no_grad_1, q_value_no_grad_2))
    alpha_value = jnp.exp(jax.lax.stop_gradient(log_alpha))
    policy_loss = jnp.mean(alpha_value * action_log_prob - no_grad_q_value)

    # J(\alpha) from equation (18) in paper.
    entropy_diff = -action_log_prob - target_entropy
    alpha_loss = jnp.mean(log_alpha * jax.lax.stop_gradient(entropy_diff))

    # Giving a smaller weight to the critic empirically gives better results
    combined_loss = 0.5 * critic_loss + 1.0 * policy_loss + 1.0 * alpha_loss
    return combined_loss, {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
        'alpha_loss': alpha_loss,
        'critic_value_1': q_value_1,
        'critic_value_2': q_value_2,
        'target_value_1': target_q_value_1,
        'target_value_2': target_q_value_2,
        'mean_action': mean_action
    }

  grad_fn = jax.vmap(
      jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True),
      in_axes=(None, None, 0, 0, 0, 0, 0, 0))

  rng = jnp.stack(jax.random.split(key, num=batch_size))
  (_, aux_vars), gradients = grad_fn(network_params, log_alpha, states, actions,
                                     rewards, next_states, terminals, rng)

  # This calculates the mean gradient/aux_vars using the individual
  # gradients/aux_vars from each item in the batch.
  gradients = jax.tree_map(functools.partial(jnp.mean, axis=0), gradients)
  aux_vars = jax.tree_map(functools.partial(jnp.mean, axis=0), aux_vars)
  network_gradient, alpha_gradient = gradients

  # Apply gradients to all the optimizers.
  updates, optimizer_state = optim.update(network_gradient, optimizer_state,
                                          params=network_params)
  network_params = optax.apply_updates(network_params, updates)
  alpha_updates, alpha_optimizer_state = alpha_optim.update(
      alpha_gradient, alpha_optimizer_state, params=log_alpha)
  log_alpha = optax.apply_updates(log_alpha, alpha_updates)

  # Compile everything in a dict.
  returns = {
      'network_params': network_params,
      'log_alpha': log_alpha,
      'optimizer_state': optimizer_state,
      'alpha_optimizer_state': alpha_optimizer_state,
      'Losses/Critic': aux_vars['critic_loss'],
      'Losses/Actor': aux_vars['policy_loss'],
      'Losses/Alpha': aux_vars['alpha_loss'],
      'Values/CriticValues1': jnp.mean(aux_vars['critic_value_1']),
      'Values/CriticValues2': jnp.mean(aux_vars['critic_value_2']),
      'Values/TargetValues1': jnp.mean(aux_vars['target_value_1']),
      'Values/TargetValues2': jnp.mean(aux_vars['target_value_2']),
      'Values/Alpha': jnp.exp(log_alpha),
  }
  for i, a in enumerate(aux_vars['mean_action']):
    returns.update({f'Values/MeanActions{i}': a})
  return returns
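
The example above computes per-transition losses and gradients by wrapping
jax.value_and_grad in jax.vmap and then averaging the results over the batch
with a tree map. A minimal sketch of that pattern on a made-up regression loss
(all names here are illustrative):

import functools

import jax
import jax.numpy as jnp

def per_example_loss(w, b, x, y):
    # Toy squared-error loss for a single example.
    pred = jnp.dot(x, w) + b
    return jnp.square(pred - y)

# value_and_grad over (w, b), vmapped over the batch dimension of (x, y).
grad_fn = jax.vmap(
    jax.value_and_grad(per_example_loss, argnums=(0, 1)),
    in_axes=(None, None, 0, 0))

w, b = jnp.ones(3), jnp.zeros(())
xs, ys = jnp.ones((8, 3)), jnp.ones((8,))
losses, grads = grad_fn(w, b, xs, ys)  # Per-example losses and gradients.
mean_grads = jax.tree_util.tree_map(
    functools.partial(jnp.mean, axis=0), grads)  # Average over the batch.
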
Example #18
  def __init__(self,
               network: networks_lib.FeedForwardNetwork,
               random_key: networks_lib.PRNGKey,
               loss_fn: losses.Loss,
               optimizer: optax.GradientTransformation,
               demonstrations: Iterator[types.Transition],
               num_sgd_steps_per_step: int,
               logger: Optional[loggers.Logger] = None,
               counter: Optional[counting.Counter] = None):
    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      loss_and_grad = jax.value_and_grad(loss_fn, argnums=1)

      # Compute losses and their gradients.
      key, key_input = jax.random.split(state.key)
      loss_value, gradients = loss_and_grad(network.apply, state.policy_params,
                                            key_input, transitions)

      policy_update, optimizer_state = optimizer.update(gradients,
                                                        state.optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      new_state = TrainingState(
          optimizer_state=optimizer_state,
          policy_params=policy_params,
          key=key,
          steps=state.steps + 1,
      )
      metrics = {
          'loss': loss_value,
          'gradient_norm': optax.global_norm(gradients)
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='learner')
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # Split the input batch into `num_sgd_steps_per_step` minibatches in order
    # to achieve better performance on accelerators.
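    # (In acme, process_multiple_batches reshapes each incoming batch into
    # `num_sgd_steps_per_step` minibatches and scans `sgd_step` over them inside a
    # single jitted call; the exact behaviour depends on the acme version in use.)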
    self._sgd_step = jax.jit(utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step))

    random_key, init_key = jax.random.split(random_key)
    policy_params = network.init(init_key)
    optimizer_state = optimizer.init(policy_params)

    # Create initial state.
    self._state = TrainingState(
        optimizer_state=optimizer_state,
        policy_params=policy_params,
        key=random_key,
        steps=0,
    )

    self._timestamp = None
Example #19
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: Callable[[jnp.ndarray], jnp.ndarray],
        optimizer: optax.GradientTransformation,
        batch_size: int,
        epsilon: float,
        rng: hk.PRNGSequence,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
    ):
        # Transform the (impure) network into a pure function.
        network = hk.without_apply_rng(hk.transform(network, apply_rng=True))

        # Define loss function.
        def loss(params: hk.Params, target_params: hk.Params,
                 transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Computes the standard TD(0) Q-learning loss on batch of transitions."""
            o_tm1, a_tm1, r_t, d_t, o_t = transitions
            q_tm1 = network.apply(params, o_tm1)
            q_t = network.apply(target_params, o_t)
            batch_q_learning = jax.vmap(rlax.q_learning)
            td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
            return jnp.mean(td_error**2)

        # Define update function.
        @jax.jit
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Performs an SGD step on a batch of transitions."""
            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)

        # Initialize the networks and optimizer.
        dummy_observation = np.zeros((1, *obs_spec.shape), jnp.float32)
        initial_params = network.init(next(rng), dummy_observation)
        initial_target_params = network.init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        # This carries the agent state relevant to training.
        self._state = TrainingState(params=initial_params,
                                    target_params=initial_target_params,
                                    opt_state=initial_opt_state,
                                    step=0)
        self._sgd_step = sgd_step
        self._forward = jax.jit(network.apply)
        self._replay = replay.Replay(capacity=replay_capacity)

        # Store hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._epsilon = epsilon
        self._total_steps = 0
        self._min_replay_size = min_replay_size
Example #20
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 loss_fn: LossFn,
                 optimizer: optax.GradientTransformation,
                 data_iterator: Iterator[reverb.ReplaySample],
                 target_update_period: int,
                 random_key: networks_lib.PRNGKey,
                 replay_client: Optional[reverb.Client] = None,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initialize the SGD learner."""
        self.network = network

        # Internalize the loss_fn with network.
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        # SGD performs the loss, optimizer update and periodic target net update.
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Implements one SGD step of the loss and updates training state
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra

        def postprocess_aux(extra: LossExtra) -> LossExtra:
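            # Flatten the leading [num_sgd_steps_per_step, batch, ...] axes of the
            # reverb updates and average the per-minibatch metrics.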
            reverb_update = jax.tree_map(
                lambda a: jnp.reshape(a, (-1, *a.shape[2:])),
                extra.reverb_update)
            return extra._replace(metrics=jax.tree_map(jnp.mean,
                                                       extra.metrics),
                                  reverb_update=reverb_update)

        self._num_sgd_steps_per_step = num_sgd_steps_per_step
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step,
                                                  postprocess_aux)
        self._sgd_step = jax.jit(sgd_step)

        # Internalise agent components
        self._data_iterator = utils.prefetch(data_iterator)
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None

        # Initialize the network parameters
        key_params, key_target, key_state = jax.random.split(random_key, 3)
        initial_params = self.network.init(key_params)
        initial_target_params = self.network.init(key_target)
        self._state = TrainingState(
            params=initial_params,
            target_params=initial_target_params,
            opt_state=optimizer.init(initial_params),
            steps=0,
            rng_key=key_state,
        )

        # Update replay priorities
        def update_priorities(reverb_update: ReverbUpdate) -> None:
            if replay_client is None:
                return
            keys, priorities = tree.map_structure(
                utils.fetch_devicearray,
                (reverb_update.keys, reverb_update.priorities))
            replay_client.mutate_priorities(table=replay_table_name,
                                            updates=dict(zip(keys,
                                                             priorities)))

        self._replay_client = replay_client
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
Example #21
    def __init__(self,
                 observation_spec: specs.Array,
                 action_spec: specs.DiscreteArray,
                 network: PolicyValueNet,
                 optimizer: optax.GradientTransformation,
                 rng: hk.PRNGSequence,
                 buffer_length: int,
                 discount: float,
                 td_lambda: float,
                 entropy_cost: float = 1.,
                 critic_cost: float = 1.):
        @jax.jit
        def pack(trajectory: buffer.Trajectory) -> List[jnp.ndarray]:
            """Converts a trajectory into an input."""
            observations = trajectory.observations[:, None, ...]

            rewards = jnp.concatenate([
                trajectory.previous_reward,
                jnp.squeeze(trajectory.rewards, -1)
            ], -1)
            rewards = jnp.expand_dims(rewards, (1, 2))

            previous_action = jax.nn.one_hot(trajectory.previous_action,
                                             action_spec.num_values)
            actions = jax.nn.one_hot(jnp.squeeze(trajectory.actions, 1),
                                     action_spec.num_values)
            actions = jnp.expand_dims(
                jnp.concatenate([previous_action, actions], 0), 1)

            return observations, rewards, actions

        @jax.jit
        def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            observations, rewards, actions = pack(trajectory)
            logits, values, _, _, _ = network(observations, rewards, actions)

            td_errors = rlax.td_lambda(v_tm1=values[:-1],
                                       r_t=jnp.squeeze(trajectory.rewards, -1),
                                       discount_t=trajectory.discounts *
                                       discount,
                                       v_t=values[1:],
                                       lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))

            return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss

        # Transform the loss into a pure function.
        loss_fn = hk.without_apply_rng(hk.transform(loss,
                                                    apply_rng=True)).apply

        # Define update function.
        @jax.jit
        def sgd_step(state: AgentState,
                     trajectory: buffer.Trajectory) -> AgentState:
            """Performs a step of SGD over a trajectory."""
            gradients = jax.grad(loss_fn)(state.params, trajectory)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)
            return AgentState(params=new_params, opt_state=new_opt_state)

        # Initialize network parameters and optimiser state.
        init, forward = hk.without_apply_rng(
            hk.transform(network, apply_rng=True))
        dummy_observation = jnp.zeros((1, *observation_spec.shape),
                                      dtype=observation_spec.dtype)
        dummy_reward = jnp.zeros((1, 1, 1))
        dummy_action = jnp.zeros((1, 1, action_spec.num_values))
        initial_params = init(next(rng), dummy_observation, dummy_reward,
                              dummy_action)
        initial_opt_state = optimizer.init(initial_params)

        # Internalize state.
        self._state = AgentState(initial_params, initial_opt_state)
        self._forward = jax.jit(forward)
        self._buffer = buffer.Buffer(observation_spec, action_spec,
                                     buffer_length)
        self._sgd_step = sgd_step
        self._rng = rng
        self._action_spec = action_spec
Example #22
    def __init__(
        self,
        preprocessor: processors.Processor,
        sample_network_input: jnp.ndarray,
        network: parts.Network,
        support: jnp.ndarray,
        optimizer: optax.GradientTransformation,
        transition_accumulator: Any,
        replay: replay_lib.TransitionReplay,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        rng_key: parts.PRNGKey,
    ):
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.
        self._statistics = {'state_value': np.nan}

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key = jax.random.split(rng_key, 3)
            logits_q_tm1 = network.apply(online_params, online_key,
                                         transitions.s_tm1).q_logits
            logits_target_q_t = network.apply(target_params, target_key,
                                              transitions.s_t).q_logits
            losses = _batch_categorical_q_learning(
                support,
                logits_q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                support,
                logits_target_q_t,
            )
            chex.assert_shape(losses, (self._batch_size, ))
            loss = jnp.mean(losses)
            return loss

        def update(rng_key, opt_state, online_params, target_params,
                   transitions):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                                transitions, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).q_values[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            v_t = jnp.max(q_t, axis=-1)
            return rng_key, a_t, v_t

        self._select_action = jax.jit(select_action)
Example #23
def solve_dual_train(
    env: Dict[int, DualOp],
    dual_state: ConfigDict,
    opt: optax.GradientTransformation,
    inner_opt: InnerMaxStrategy,
    dual_params: Params,
    spec_type: verify_utils.SpecType,
    dual_params_types: ParamsTypes,
    logger: Callable[[int, Mapping[str, Any]], None],
    key: jnp.array,
    num_steps: int,
    affine_before_relu: bool,
    device_type=None,
    merge_problems: Optional[Dict[int, int]] = None,
    block_to_time: bool = False,
) -> ConfigDict:
    """Compute verified upper bound via functional lagrangian relaxation.

  Args:
    env: Lagrangian computations for each contributing graph node.
    dual_state: state of the dual problem.
    opt: an optimizer for the outer Lagrangian parameters.
    inner_opt: inner optimization strategy for training.
    dual_params: dual parameters to be minimized via gradient-based
      optimization.
    spec_type: Specification type, adversarial or uncertainty specification.
    dual_params_types: types of inequality encoded by the corresponding
      dual_params.
    logger: logging function.
    key: jax.random.PRNGKey.
    num_steps: total number of outer optimization steps.
    affine_before_relu: whether layer ordering uses the affine layer before the
      ReLU.
    device_type: string, used to clamp to a particular hardware device. Default
      None uses JAX default device placement.
    merge_problems: the key of the dictionary corresponds to the index of the
      layer to begin the merge, and the associated value corresponds to the
      number of consecutive layers to be merged with it.
      For example, `{0: 2, 2: 3}` will merge together layer 0 and 1, as well as
        layers 2, 3 and 4.
    block_to_time: whether to block computations at the end of each iteration to
      account for asynchronicity dispatch when timing.

  Returns:
    dual_state: new state of the dual problem (per-step statistics are reported
      via `logger` rather than returned).
  """
    assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'

    # create dual functions
    loss_func = dual_build.build_dual_fun(
        env=env,
        lagrangian_form=dual_params_types.lagrangian_form,
        inner_opt=inner_opt,
        merge_problems=merge_problems,
        affine_before_relu=affine_before_relu,
        spec_type=spec_type)

    value_and_grad = jax.value_and_grad(loss_func, has_aux=True)

    def grad_step(params, opt_state, key, step):
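        # One outer step: evaluate the Lagrangian bound and its gradient w.r.t. the
        # dual parameters, then apply the optax update.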
        (loss_val, stats), g = value_and_grad(params, key, step)
        updates, new_opt_state = opt.update(g, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_val, stats

    # Some solvers (e.g. MIP) cannot be jitted and run on CPU only
    if inner_opt.jittable:
        grad_step = jax.jit(grad_step, backend=device_type)

    dual_state.step = 0
    dual_state.key = key
    dual_state.opt_state = opt.init(dual_params)
    dual_state.dual_params = dual_params
    dual_state.loss = 0.0

    dual_state.best_loss = jnp.inf
    dual_state.best_dual_params = dual_params

    # optimize the dual (Lagrange) parameters with a gradient-based optimizer
    while dual_state.step < num_steps:
        key_step, dual_state.key = jax.random.split(dual_state.key)
        start_time = time.time()
        dual_params, dual_state.opt_state, dual_state.loss, stats = grad_step(
            dual_state.dual_params, dual_state.opt_state, key_step,
            dual_state.step)
        dual_params = dual_build.project_dual(dual_params, dual_params_types)
        if dual_state.loss <= dual_state.best_loss:
            dual_state.best_loss = dual_state.loss
            # store value from previous iteration as loss corresponds to those params
            dual_state.best_dual_params = dual_state.dual_params
        dual_state.dual_params = dual_params  # projected dual params
        if block_to_time:
            dual_state.loss.block_until_ready()  # asynchronous dispatch
        stats['time_per_iteration'] = time.time() - start_time
        stats['best_loss'] = dual_state.best_loss
        stats['dual_params_norm'] = optax.global_norm(dual_state.dual_params)

        logger(dual_state.step, stats)

        dual_state.step += 1

    return dual_state
Example #24
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: Callable[[jnp.ndarray], jnp.ndarray],
        num_ensemble: int,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: optax.GradientTransformation,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = 1,
    ):
        # Transform the (impure) network into a pure function.
        network = hk.without_apply_rng(hk.transform(network, apply_rng=True))

        # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`.
        def loss(params: hk.Params, target_params: hk.Params,
                 transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Q-learning loss with added reward noise + half-in bootstrap."""
            o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
            q_tm1 = network.apply(params, o_tm1)
            q_t = network.apply(target_params, o_t)
            r_t += noise_scale * z_t
            batch_q_learning = jax.vmap(rlax.q_learning)
            td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
            return jnp.mean(m_t * td_error**2)

        # Define the update function for each member of the ensemble.
        @jax.jit
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Does a step of SGD for the whole ensemble over `transitions`."""

            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)

        # Initialize parameters and optimizer state for an ensemble of Q-networks.
        rng = hk.PRNGSequence(seed)
        dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
        initial_params = [
            network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
        ]
        initial_target_params = [
            network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
        ]
        initial_opt_state = [optimizer.init(p) for p in initial_params]

        # Internalize state.
        self._ensemble = [
            TrainingState(p, tp, o, step=0) for p, tp, o in zip(
                initial_params, initial_target_params, initial_opt_state)
        ]
        self._forward = jax.jit(network.apply)
        self._sgd_step = sgd_step
        self._num_ensemble = num_ensemble
        self._optimizer = optimizer
        self._replay = replay.Replay(capacity=replay_capacity)

        # Agent hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._mask_prob = mask_prob

        # Agent state.
        self._active_head = self._ensemble[0]
        self._total_steps = 0
Example #25
    def __init__(self,
                 unroll: networks_lib.FeedForwardNetwork,
                 initial_state: networks_lib.FeedForwardNetwork,
                 batch_size: int,
                 random_key: networks_lib.PRNGKey,
                 burn_in_length: int,
                 discount: float,
                 importance_sampling_exponent: float,
                 max_priority_weight: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 bootstrap_n: int = 5,
                 tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
                 clip_rewards: bool = False,
                 max_abs_reward: float = 1.,
                 use_core_state: bool = True,
                 prefetch_size: int = 2,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner."""

        random_key, key_initial_1, key_initial_2 = jax.random.split(
            random_key, 3)
        initial_state_params = initial_state.init(key_initial_1, batch_size)
        initial_state = initial_state.apply(initial_state_params,
                                            key_initial_2, batch_size)

        def loss(
            params: networks_lib.Params, target_params: networks_lib.Params,
            key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample
        ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
            """Computes mean transformed N-step loss for a batch of sequences."""

            # Convert sample data to sequence-major format [T, B, ...].
            data = utils.batch_to_sequence(sample.data)

            # Get core state & warm it up on observations for a burn-in period.
            if use_core_state:
                # Replay core state.
                online_state = jax.tree_map(lambda x: x[0],
                                            data.extras['core_state'])
            else:
                online_state = initial_state
            target_state = online_state

            # Maybe burn the core state in.
            if burn_in_length:
                burn_obs = jax.tree_map(lambda x: x[:burn_in_length],
                                        data.observation)
                key_grad, key1, key2 = jax.random.split(key_grad, 3)
                _, online_state = unroll.apply(params, key1, burn_obs,
                                               online_state)
                _, target_state = unroll.apply(target_params, key2, burn_obs,
                                               target_state)

            # Only get data to learn on from after the end of the burn in period.
            data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

            # Unroll on sequences to get online and target Q-Values.
            key1, key2 = jax.random.split(key_grad)
            online_q, _ = unroll.apply(params, key1, data.observation,
                                       online_state)
            target_q, _ = unroll.apply(target_params, key2, data.observation,
                                       target_state)

            # Get value-selector actions from online Q-values for double Q-learning.
            selector_actions = jnp.argmax(online_q, axis=-1)
            # Preprocess discounts & rewards.
            discounts = (data.discount * discount).astype(online_q.dtype)
            rewards = data.reward
            if clip_rewards:
                rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
            rewards = rewards.astype(online_q.dtype)

            # Get N-step transformed TD error and loss.
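            # rlax.transformed_n_step_q_learning computes double-Q n-step returns in
            # the transformed space defined by `tx_pair` (signed hyperbolic by
            # default), using `selector_actions` from the online network to pick
            # target values.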
            batch_td_error_fn = jax.vmap(functools.partial(
                rlax.transformed_n_step_q_learning,
                n=bootstrap_n,
                tx_pair=tx_pair),
                                         in_axes=1,
                                         out_axes=1)
            # TODO(b/183945808): when this bug is fixed, truncations of actions,
            # rewards, and discounts will no longer be necessary.
            batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                               target_q[1:],
                                               selector_actions[1:],
                                               rewards[:-1], discounts[:-1])
            batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

            # Importance weighting.
            probs = sample.info.probability
            importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)
            mean_loss = jnp.mean(importance_weights * batch_loss)

            # Calculate priorities as a mixture of max and mean sequence errors.
            abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
            max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
            mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error,
                                                                 axis=0)
            priorities = (max_priority + mean_priority)

            return mean_loss, priorities

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Performs an update step, averaging over pmap replicas."""

            # Compute loss and gradients.
            grad_fn = jax.value_and_grad(loss, has_aux=True)
            key, key_grad = jax.random.split(state.random_key)
            (loss_value,
             priorities), gradients = grad_fn(state.params,
                                              state.target_params, key_grad,
                                              samples)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply optimizer updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps,
                                      random_key=key)
            return new_state, priorities, {'loss': loss_value}

        def update_priorities(keys_and_priorities: Tuple[jnp.ndarray,
                                                         jnp.ndarray]):
            keys, priorities = keys_and_priorities
            keys, priorities = tree.map_structure(
                # Fetch array and combine device and batch dimensions.
                lambda x: utils.fetch_devicearray(x).reshape(
                    (-1, ) + x.shape[2:]),
                (keys, priorities))
            replay_client.mutate_priorities(  # pytype: disable=attribute-error
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Internalise components, hyperparameters, logger, counter, and methods.
        self._replay_client = replay_client
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            time_delta=1.)

        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)

        # Initialise and internalise training state (parameters/optimiser state).
        random_key, key_init = jax.random.split(random_key)
        initial_params = unroll.init(key_init, initial_state)
        opt_state = optimizer.init(initial_params)

        state = TrainingState(params=initial_params,
                              target_params=initial_params,
                              opt_state=opt_state,
                              steps=jnp.array(0),
                              random_key=random_key)
        # Replicate parameters.
        self._state = utils.replicate_in_all_devices(state)

        # Shard multiple inputs with on-device prefetching.
        # We split each sample into two outputs: the keys, which must stay on
        # the host because TPUs do not support int64 arrays, and the full
        # sample, which is sent on-device to the sgd_step method.
        def split_sample(
                sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
            return utils.PrefetchingSplit(host=sample.info.key, device=sample)

        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            num_threads=jax.local_device_count(),
            split_fn=split_sample)
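
# A minimal self-contained sketch of the prioritised-replay arithmetic used in
# the loss above: importance weights are the inverse sampling probabilities
# raised to `importance_sampling_exponent` and normalised by their maximum, and
# the new priorities are a convex mixture of the max and mean absolute TD error
# over each sequence. The toy shapes and hyperparameter values below are
# illustrative assumptions, not taken from the original configuration.
import jax.numpy as jnp


def priority_sketch(batch_td_error, probs,
                    importance_sampling_exponent=0.6,
                    max_priority_weight=0.9):
    # batch_td_error: [T, B] per-step TD errors; probs: [B] sampling probabilities.
    batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)
    importance_weights = (1. / (probs + 1e-6)) ** importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)
    mean_loss = jnp.mean(importance_weights * batch_loss)
    abs_td_error = jnp.abs(batch_td_error)
    priorities = (max_priority_weight * jnp.max(abs_td_error, axis=0) +
                  (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0))
    return mean_loss, priorities


# Example: a batch of 2 sequences of length 3.
_loss, _priorities = priority_sketch(
    batch_td_error=jnp.array([[0.1, 2.0], [0.2, 1.0], [0.3, 0.5]]),
    probs=jnp.array([0.5, 0.01]))
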
Exemple #26
0
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 random_key: networks_lib.PRNGKey,
                 loss_fn: losses.Loss,
                 optimizer: optax.GradientTransformation,
                 prefetching_iterator: Iterator[types.Transition],
                 num_sgd_steps_per_step: int,
                 loss_has_aux: bool = False,
                 logger: Optional[loggers.Logger] = None,
                 counter: Optional[counting.Counter] = None):
        """Behavior Cloning Learner.

    Args:
      network: Network with apply signature (params, obs, is_training,
        key) -> jnp.ndarray and init signature (rng, is_training) -> params.
      random_key: RNG key.
      loss_fn: BC loss to use.
      optimizer: Optax optimizer.
      prefetching_iterator: A sharded prefetching iterator, as output by
        `acme.jax.utils.sharded_prefetch`; see the documentation of
        `sharded_prefetch` for more details.
      num_sgd_steps_per_step: Number of gradient updates per step.
      loss_has_aux: Whether the loss function returns auxiliary metrics as a
        second output.
      logger: Logger.
      counter: Counter.
    """
        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            loss_and_grad = jax.value_and_grad(loss_fn,
                                               argnums=1,
                                               has_aux=loss_has_aux)

            # Compute losses and their gradients.
            key, key_input = jax.random.split(state.key)
            loss_result, gradients = loss_and_grad(network.apply,
                                                   state.policy_params,
                                                   key_input, transitions)

            # Combine the gradient across all devices (by taking their mean).
            gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME)

            # Compute and combine metrics across all devices.
            metrics = _create_loss_metrics(loss_has_aux, loss_result,
                                           gradients)
            metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME)

            policy_update, optimizer_state = optimizer.update(
                gradients, state.optimizer_state, state.policy_params)
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_update)

            new_state = TrainingState(
                optimizer_state=optimizer_state,
                policy_params=policy_params,
                key=key,
                steps=state.steps + 1,
            )

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter(prefix='learner')
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Split the input batch into `num_sgd_steps_per_step` minibatches to
        # achieve better performance on accelerators (a minimal sketch of this
        # batching trick follows this example).
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step)
        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

        random_key, init_key = jax.random.split(random_key)
        policy_params = network.init(init_key)
        optimizer_state = optimizer.init(policy_params)

        # Create initial state.
        state = TrainingState(
            optimizer_state=optimizer_state,
            policy_params=policy_params,
            key=random_key,
            steps=0,
        )
        self._state = utils.replicate_in_all_devices(state)

        self._timestamp = None

        self._prefetching_iterator = prefetching_iterator
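
# The constructor above relies on a batch-splitting wrapper; the following is a
# minimal sketch of that idea (an assumption for illustration, not the actual
# acme.jax.utils.process_multiple_batches implementation): the leading batch
# axis is reshaped into `num_sgd_steps_per_step` minibatches and the wrapped
# step function is scanned over them, so one jitted/pmapped call performs
# several gradient updates.
import jax
import jax.numpy as jnp


def process_multiple_batches_sketch(step_fn, num_sgd_steps_per_step):

    def wrapped_step(state, batch):
        # Reshape leaves from [N, ...] to [num_sgd_steps_per_step, N // k, ...].
        minibatches = jax.tree_util.tree_map(
            lambda x: x.reshape((num_sgd_steps_per_step, -1) + x.shape[1:]),
            batch)
        # Scan the single-minibatch step over the minibatch axis.
        state, metrics = jax.lax.scan(step_fn, state, minibatches)
        # Keep the metrics of the last minibatch (one possible convention).
        return state, jax.tree_util.tree_map(lambda m: m[-1], metrics)

    return wrapped_step
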
Exemple #27
0
  def __init__(self,
               networks: value_dice_networks.ValueDiceNetworks,
               policy_optimizer: optax.GradientTransformation,
               nu_optimizer: optax.GradientTransformation,
               discount: float,
               rng: jnp.ndarray,
               iterator_replay: Iterator[reverb.ReplaySample],
               iterator_demonstrations: Iterator[types.Transition],
               alpha: float = 0.05,
               policy_reg_scale: float = 1e-4,
               nu_reg_scale: float = 10.0,
               num_sgd_steps_per_step: int = 1,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None):

    rng, policy_key, nu_key = jax.random.split(rng, 3)
    policy_init_params = networks.policy_network.init(policy_key)
    policy_optimizer_state = policy_optimizer.init(policy_init_params)

    nu_init_params = networks.nu_network.init(nu_key)
    nu_optimizer_state = nu_optimizer.init(nu_init_params)

    def compute_losses(
        policy_params: networks_lib.Params,
        nu_params: networks_lib.Params,
        key: jnp.ndarray,
        replay_o_tm1: types.NestedArray,
        replay_a_tm1: types.NestedArray,
        replay_o_t: types.NestedArray,
        demo_o_tm1: types.NestedArray,
        demo_a_tm1: types.NestedArray,
        demo_o_t: types.NestedArray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
      # TODO(damienv, hussenot): what to do with the discounts?

      def policy(obs, key):
        dist_params = networks.policy_network.apply(policy_params, obs)
        return networks.sample(dist_params, key)

      key1, key2, key3, key4 = jax.random.split(key, 4)

      # Predicted actions.
      demo_o_t0 = demo_o_tm1
      policy_demo_a_t0 = policy(demo_o_t0, key1)
      policy_demo_a_t = policy(demo_o_t, key2)
      policy_replay_a_t = policy(replay_o_t, key3)

      replay_a_tm1 = networks.encode_action(replay_a_tm1)
      demo_a_tm1 = networks.encode_action(demo_a_tm1)
      policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0)
      policy_demo_a_t = networks.encode_action(policy_demo_a_t)
      policy_replay_a_t = networks.encode_action(policy_replay_a_t)

      # "Value function" nu over the expert states.
      nu_demo_t0 = networks.nu_network.apply(nu_params, demo_o_t0,
                                             policy_demo_a_t0)
      nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1, demo_a_tm1)
      nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t,
                                            policy_demo_a_t)
      nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t

      # "Value function" nu over the replay buffer states.
      nu_replay_tm1 = networks.nu_network.apply(nu_params, replay_o_tm1,
                                                replay_a_tm1)
      nu_replay_t = networks.nu_network.apply(nu_params, replay_o_t,
                                              policy_replay_a_t)
      nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t

      # Linear part of the loss.
      linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount))
      linear_loss_rb = jnp.mean(nu_replay_diff)
      linear_loss = (linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha)

      # Non-linear part of the loss.
      nu_replay_demo_diff = jnp.concatenate([nu_demo_diff, nu_replay_diff],
                                            axis=0)
      replay_demo_weights = jnp.concatenate([
          jnp.ones_like(nu_demo_diff) * (1 - alpha),
          jnp.ones_like(nu_replay_diff) * alpha
      ],
                                            axis=0)
      replay_demo_weights /= jnp.mean(replay_demo_weights)
      non_linear_loss = jnp.sum(
          jax.lax.stop_gradient(
              utils.weighted_softmax(nu_replay_demo_diff, replay_demo_weights,
                                     axis=0)) *
          nu_replay_demo_diff)

      # Final loss.
      loss = (non_linear_loss - linear_loss)

      # Regularized policy loss.
      if policy_reg_scale > 0.:
        policy_reg = _orthogonal_regularization_loss(policy_params)
      else:
        policy_reg = 0.

      # Gradient penalty on nu.
      if nu_reg_scale > 0.0:
        batch_size = demo_o_tm1.shape[0]
        c = jax.random.uniform(key4, shape=(batch_size,))
        shape_o = [
            dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape)
        ]
        shape_a = [
            dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape)
        ]
        c_o = jnp.reshape(c, shape_o)
        c_a = jnp.reshape(c, shape_a)
        mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1
        mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1
        mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t
        mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t
        mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0)
        mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0)

        def sum_nu(o, a):
          return jnp.sum(networks.nu_network.apply(nu_params, o, a))

        nu_grad_o_fn = jax.grad(sum_nu, argnums=0)
        nu_grad_a_fn = jax.grad(sum_nu, argnums=1)
        nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a)
        nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a)
        nu_grad = jnp.concatenate([
            jnp.reshape(nu_grad_o, [batch_size, -1]),
            jnp.reshape(nu_grad_a, [batch_size, -1])], axis=-1)
        # TODO(damienv, hussenot): check for the need of eps
        # (like in the original value dice code).
        nu_grad_penalty = jnp.mean(
            jnp.square(
                jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1))
      else:
        nu_grad_penalty = 0.0

      policy_loss = -loss + policy_reg_scale * policy_reg
      nu_loss = loss + nu_reg_scale * nu_grad_penalty

      return policy_loss, nu_loss

    def sgd_step(
        state: TrainingState,
        data: Tuple[types.Transition, types.Transition]
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      replay_transitions, demo_transitions = data
      key, key_loss = jax.random.split(state.key)
      compute_losses_with_input = functools.partial(
          compute_losses,
          replay_o_tm1=replay_transitions.observation,
          replay_a_tm1=replay_transitions.action,
          replay_o_t=replay_transitions.next_observation,
          demo_o_tm1=demo_transitions.observation,
          demo_a_tm1=demo_transitions.action,
          demo_o_t=demo_transitions.next_observation,
          key=key_loss)
      (policy_loss_value, nu_loss_value), vjpfun = jax.vjp(
          compute_losses_with_input,
          state.policy_params, state.nu_params)
      policy_gradients, _ = vjpfun((1.0, 0.0))
      _, nu_gradients = vjpfun((0.0, 1.0))

      # Update optimizers.
      policy_update, policy_optimizer_state = policy_optimizer.update(
          policy_gradients, state.policy_optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      nu_update, nu_optimizer_state = nu_optimizer.update(
          nu_gradients, state.nu_optimizer_state)
      nu_params = optax.apply_updates(state.nu_params, nu_update)

      new_state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          policy_params=policy_params,
          nu_optimizer_state=nu_optimizer_state,
          nu_params=nu_params,
          key=key,
          steps=state.steps + 1,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'nu_loss': nu_loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Iterator on demonstration transitions.
    self._iterator_demonstrations = iterator_demonstrations
    self._iterator_replay = iterator_replay

    self._sgd_step = jax.jit(utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step))

    # Create initial state.
    self._state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_init_params,
        nu_optimizer_state=nu_optimizer_state,
        nu_params=nu_init_params,
        key=rng,
        steps=0,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
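
# A small standalone illustration of the jax.vjp trick used in sgd_step above:
# for a loss with two scalar outputs, calling the VJP with the cotangent
# (1., 0.) recovers the gradient of the first output and (0., 1.) the gradient
# of the second, without differentiating the loss twice. The toy loss below is
# an assumption made up purely for the illustration.
import jax
import jax.numpy as jnp


def toy_two_losses(policy_params, nu_params):
  policy_loss = jnp.sum(policy_params ** 2) + jnp.sum(policy_params * nu_params)
  nu_loss = jnp.sum(nu_params ** 2) - jnp.sum(policy_params * nu_params)
  return policy_loss, nu_loss


policy_params = jnp.array([1.0, 2.0])
nu_params = jnp.array([3.0, -1.0])
(_policy_loss, _nu_loss), vjp_fn = jax.vjp(toy_two_losses, policy_params, nu_params)
policy_grads, _ = vjp_fn((1.0, 0.0))  # Gradient of policy_loss w.r.t. policy_params.
_, nu_grads = vjp_fn((0.0, 1.0))      # Gradient of nu_loss w.r.t. nu_params.
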
Exemple #28
0
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 obs_spec: specs.Array,
                 loss_fn: LossFn,
                 optimizer: optax.GradientTransformation,
                 data_iterator: Iterator[reverb.ReplaySample],
                 target_update_period: int,
                 random_key: networks_lib.PRNGKey,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initialize the SGD learner."""
        self.network = network

        # Internalize the loss_fn with network.
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        # sgd_step computes the loss, applies the optimizer update and
        # periodically updates the target network.
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Perform one SGD step on the loss and update the training state.
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra

        self._sgd_step = jax.jit(sgd_step)

        # Internalise agent components
        self._data_iterator = utils.prefetch(data_iterator)
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Initialize the network parameters
        dummy_obs = utils.add_batch_dim(utils.zeros_like(obs_spec))
        key_params, key_target, key_state = jax.random.split(random_key, 3)
        initial_params = self.network.init(key_params, dummy_obs)
        initial_target_params = self.network.init(key_target, dummy_obs)
        self._state = TrainingState(
            params=initial_params,
            target_params=initial_target_params,
            opt_state=optimizer.init(initial_params),
            steps=0,
            rng_key=key_state,
        )

        # Update replay priorities
        def update_priorities(reverb_update: Optional[ReverbUpdate]) -> None:
            if reverb_update is None or replay_client is None:
                return
            else:
                replay_client.mutate_priorities(
                    table=adders.DEFAULT_PRIORITY_TABLE,
                    updates=dict(
                        zip(reverb_update.keys, reverb_update.priorities)))

        self._replay_client = replay_client
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
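
# A hand-rolled sketch of the periodic target-network update used above. This
# illustrates the behaviour relied upon here (replace the target parameters
# with the online parameters every `update_period` steps, otherwise keep them
# unchanged); it is not the rlax.periodic_update implementation itself.
import jax
import jax.numpy as jnp


def periodic_update_sketch(new_params, old_params, steps, update_period):
    should_update = steps % update_period == 0
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(should_update, new, old),
        new_params, old_params)
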
Exemple #29
0
  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: IqnInputs,
      network: parts.Network,
      optimizer: optax.GradientTransformation,
      transition_accumulator: Any,
      replay: replay_lib.TransitionReplay,
      batch_size: int,
      exploration_epsilon: Callable[[int], float],
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      huber_param: float,
      tau_samples_policy: int,
      tau_samples_s_tm1: int,
      tau_samples_s_t: int,
      rng_key: parts.PRNGKey,
  ):
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key = jax.random.split(rng_key)
    self._online_params = network.init(
        network_rng_key,
        jax.tree_map(lambda x: x[None, ...], sample_network_input))
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def loss_fn(online_params, target_params, transitions, rng_key):
      """Calculates loss given network parameters and transitions."""
      # Sample tau values for q_tm1, q_t_selector, q_t.
      batch_size = self._batch_size
      rng_key, *sample_keys = jax.random.split(rng_key, 4)
      tau_tm1 = _sample_tau(sample_keys[0], (batch_size, tau_samples_s_tm1))
      tau_t_selector = _sample_tau(sample_keys[1],
                                   (batch_size, tau_samples_policy))
      tau_t = _sample_tau(sample_keys[2], (batch_size, tau_samples_s_t))

      # Compute Q value distributions.
      _, *apply_keys = jax.random.split(rng_key, 4)
      dist_q_tm1 = network.apply(online_params, apply_keys[0],
                                 IqnInputs(transitions.s_tm1, tau_tm1)).q_dist
      dist_q_t_selector = network.apply(
          target_params, apply_keys[1],
          IqnInputs(transitions.s_t, tau_t_selector)).q_dist
      dist_q_target_t = network.apply(target_params, apply_keys[2],
                                      IqnInputs(transitions.s_t, tau_t)).q_dist
      losses = _batch_quantile_q_learning(
          dist_q_tm1,
          tau_tm1,
          transitions.a_tm1,
          transitions.r_t,
          transitions.discount_t,
          dist_q_t_selector,
          dist_q_target_t,
          huber_param,
      )
      assert losses.shape == (self._batch_size,)
      loss = jnp.mean(losses)
      return loss

    def update(rng_key, opt_state, online_params, target_params, transitions):
      """Computes learning update from batch of replay transitions."""
      rng_key, update_key = jax.random.split(rng_key)
      d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                          transitions, update_key)
      updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return rng_key, new_opt_state, new_online_params

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, sample_key, apply_key, policy_key = jax.random.split(rng_key, 4)
      tau_t = _sample_tau(sample_key, (1, tau_samples_policy))
      q_t = network.apply(network_params, apply_key,
                          IqnInputs(s_t[None, ...], tau_t)).q_values[0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      return rng_key, a_t

    self._select_action = jax.jit(select_action)
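
# The `_sample_tau` helper used above is not shown in this snippet. A plausible
# minimal sketch, assuming IQN's quantile levels are drawn i.i.d. uniformly
# from (0, 1) (an assumption for illustration, not the original helper):
import jax
import jax.numpy as jnp


def _sample_tau_sketch(rng_key, shape):
  # Quantile levels tau, e.g. shape = (batch_size, num_tau_samples).
  return jax.random.uniform(rng_key, shape=shape, dtype=jnp.float32)
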
Exemple #30
0
    def __init__(
        self,
        preprocessor: processors.Processor,
        sample_network_input: jnp.ndarray,
        network: parts.Network,
        optimizer: optax.GradientTransformation,
        transition_accumulator: replay_lib.TransitionAccumulator,
        replay: replay_lib.PrioritizedTransitionReplay,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        grad_error_bound: float,
        rng_key: parts.PRNGKey,
    ):
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.
        self._max_seen_priority = 1.

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, weights,
                    rng_key):
            """Calculates loss given network parameters and transitions."""
            _, *apply_keys = jax.random.split(rng_key, 4)
            q_tm1 = network.apply(online_params, apply_keys[0],
                                  transitions.s_tm1).q_values
            q_t = network.apply(online_params, apply_keys[1],
                                transitions.s_t).q_values
            q_target_t = network.apply(target_params, apply_keys[2],
                                       transitions.s_t).q_values
            td_errors = _batch_double_q_learning(
                q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                q_target_t,
                q_t,
            )
            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size, ) == weights.shape
            # This is not the same as using a Huber loss and multiplying by weights.
            loss = jnp.mean(losses * weights)
            return loss, td_errors

        def update(rng_key, opt_state, online_params, target_params,
                   transitions, weights):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params, td_errors = jax.grad(loss_fn, has_aux=True)(
                online_params, target_params, transitions, weights, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params, td_errors

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).q_values[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            return rng_key, a_t

        self._select_action = jax.jit(select_action)
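
# The `_batch_double_q_learning` helper used above is not shown in this
# snippet. A minimal sketch of the standard double Q-learning TD error it
# presumably computes, vmapped over the batch (an illustrative assumption, not
# the original helper):
import jax
import jax.numpy as jnp


def _double_q_learning_sketch(q_tm1, a_tm1, r_t, discount_t, q_target_t, q_t):
    # Select the next action with the online network, evaluate it with the
    # target network (the "double" part), then form the one-step TD error.
    a_t = jnp.argmax(q_t)
    target = r_t + discount_t * q_target_t[a_t]
    return jax.lax.stop_gradient(target) - q_tm1[a_tm1]


_batch_double_q_learning_sketch = jax.vmap(_double_q_learning_sketch)
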