def rnd_update_step(
    state: RNDTrainingState, transitions: types.Transition,
    loss_fn: RNDLoss, optimizer: optax.GradientTransformation
) -> Tuple[RNDTrainingState, Dict[str, jnp.ndarray]]:
  """Runs an update step on the given transitions.

  Args:
    state: The learner state.
    transitions: Transitions to update on.
    loss_fn: The loss function.
    optimizer: The optimizer of the predictor network.

  Returns:
    A new state and metrics.
  """
  loss, grads = jax.value_and_grad(loss_fn)(
      state.params, state.target_params, transitions=transitions)
  update, optimizer_state = optimizer.update(grads, state.optimizer_state)
  params = optax.apply_updates(state.params, update)
  new_state = RNDTrainingState(
      optimizer_state=optimizer_state,
      params=params,
      target_params=state.target_params,
      steps=state.steps + 1,
  )
  return new_state, {'rnd_loss': loss}
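# Hypothetical wiring sketch (not from the source): this update step is typically
# specialised with `functools.partial` and jitted once, mirroring how the RND
# learner below builds `self._update`. The names `rnd_loss_fn`,
# `predictor_optimizer`, `state` and `transitions` are assumptions.
import functools
import jax

update = jax.jit(
    functools.partial(
        rnd_update_step, loss_fn=rnd_loss_fn, optimizer=predictor_optimizer))
state, metrics = update(state, transitions)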
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: PolicyValueNet,
    optimizer: optax.GradientTransformation,
    rng: hk.PRNGSequence,
    sequence_length: int,
    discount: float,
    td_lambda: float,
):

  # Define loss function.
  def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
    """Actor-critic loss."""
    logits, values = network(trajectory.observations)
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones_like(td_errors))
    return actor_loss + critic_loss

  # Transform the loss into a pure function.
  loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: TrainingState,
               trajectory: sequence.Trajectory) -> TrainingState:
    """Does a step of SGD over a trajectory."""
    gradients = jax.grad(loss_fn)(state.params, trajectory)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)

  # Initialize network parameters and optimiser state.
  init, forward = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
  initial_params = init(next(rng), dummy_observation)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = TrainingState(initial_params, initial_opt_state)
  self._forward = jax.jit(forward)
  self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
  self._sgd_step = sgd_step
  self._rng = rng
def clip_by_global_norm(max_norm) -> GradientTransformation:
  """Clip updates using their global norm.

  References:
    [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

  Args:
    max_norm: the maximum global norm for an update.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ClipByGlobalNormState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = global_norm(updates)
    trigger = g_norm < max_norm
    updates = jax.tree_map(
        lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
    return updates, state

  return GradientTransformation(init_fn, update_fn)
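# A minimal usage sketch (an assumption, not from the source): the transformation
# above composes with other optax transformations via `optax.chain`, e.g. global
# norm clipping followed by plain SGD scaling.
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
tx = optax.chain(clip_by_global_norm(max_norm=1.0), optax.scale(-0.1))
opt_state = tx.init(params)
grads = jax.tree_util.tree_map(lambda p: 10.0 * jnp.ones_like(p), params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)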
def sgld_gradient_update(step_size_fn,
                         seed,
                         momentum_decay=0.,
                         preconditioner=None):
  """Optax implementation of the SGLD optimizer.

  If momentum_decay is set to zero, we get the SGLD method [1]. Otherwise,
  we get the underdamped SGLD (SGHMC) method [2].

  Args:
    step_size_fn: a function taking training step as input and producing the
      step size as output.
    seed: int, random seed.
    momentum_decay: float, momentum decay parameter (default: 0).
    preconditioner: Preconditioner, an object representing the preconditioner
      or None; if None, identity preconditioner is used (default: None).

  [1] "Bayesian Learning via Stochastic Gradient Langevin Dynamics"
      Max Welling, Yee Whye Teh; ICML 2011
  [2] "Stochastic Gradient Hamiltonian Monte Carlo"
      Tianqi Chen, Emily B. Fox, Carlos Guestrin; ICML 2014
  """
  if preconditioner is None:
    preconditioner = get_identity_preconditioner()

  def init_fn(params):
    return OptaxSGLDState(
        count=jnp.zeros([], jnp.int32),
        rng_key=jax.random.PRNGKey(seed),
        momentum=jax.tree_map(jnp.zeros_like, params),
        preconditioner_state=preconditioner.init(params))

  def update_fn(gradient, state, params=None):
    del params
    lr = step_size_fn(state.count)
    lr_sqrt = jnp.sqrt(lr)
    noise_std = jnp.sqrt(2 * (1 - momentum_decay))

    preconditioner_state = preconditioner.update_preconditioner(
        gradient, state.preconditioner_state)

    noise, new_key = tree_utils.normal_like_tree(gradient, state.rng_key)
    noise = preconditioner.multiply_by_m_sqrt(noise, preconditioner_state)

    def update_momentum(m, g, n):
      return momentum_decay * m + g * lr_sqrt + n * noise_std

    momentum = jax.tree_map(update_momentum, state.momentum, gradient, noise)
    updates = preconditioner.multiply_by_m_inv(momentum, preconditioner_state)
    updates = jax.tree_map(lambda m: m * lr_sqrt, updates)
    return updates, OptaxSGLDState(
        count=state.count + 1,
        rng_key=new_key,
        momentum=momentum,
        preconditioner_state=preconditioner_state)

  return GradientTransformation(init_fn, update_fn)
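# A minimal usage sketch (assumptions: a toy standard-normal log-density, a
# constant step-size schedule, and the module's identity preconditioner /
# tree_utils helpers). Because the returned updates are *added* to the
# parameters, the caller passes gradients of the log-density so a step ascends
# it with injected Gaussian noise.
import jax
import jax.numpy as jnp
import optax

params = jnp.ones(4)
log_density = lambda p: -0.5 * jnp.sum(p ** 2)  # standard normal, up to a constant
sgld = sgld_gradient_update(step_size_fn=lambda step: 1e-3, seed=0)
sgld_state = sgld.init(params)
grads = jax.grad(log_density)(params)
updates, sgld_state = sgld.update(grads, sgld_state)
params = optax.apply_updates(params, updates)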
def __init__(self, direct_rl_learner_factory: Callable[ [Any, Iterator[reverb.ReplaySample]], acme.Learner], iterator: Iterator[reverb.ReplaySample], optimizer: optax.GradientTransformation, rnd_network: rnd_networks.RNDNetworks, rng_key: jnp.ndarray, grad_updates_per_batch: int, is_sequence_based: bool, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None): self._is_sequence_based = is_sequence_based target_key, predictor_key = jax.random.split(rng_key) target_params = rnd_network.target.init(target_key) predictor_params = rnd_network.predictor.init(predictor_key) optimizer_state = optimizer.init(predictor_params) self._state = RNDTrainingState(optimizer_state=optimizer_state, params=predictor_params, target_params=target_params, steps=0) # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray, steps_key=self._counter.get_steps_key()) loss = functools.partial(rnd_loss, networks=rnd_network) self._update = functools.partial(rnd_update_step, loss_fn=loss, optimizer=optimizer) self._update = utils.process_multiple_batches(self._update, grad_updates_per_batch) self._update = jax.jit(self._update) self._get_reward = jax.jit( functools.partial(rnd_networks.compute_rnd_reward, networks=rnd_network)) # Generator expression that works the same as an iterator. # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions updated_iterator = (self._process_sample(sample) for sample in iterator) self._direct_rl_learner = direct_rl_learner_factory( rnd_network.direct_rl_networks, updated_iterator) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
def __init__(self, network: hk.Transformed, obs_spec: specs.Array, optimizer: optax.GradientTransformation, rng: hk.PRNGSequence, dataset: tf.data.Dataset, loss_fn: LossFn = _sparse_categorical_cross_entropy, counter: counting.Counter = None, logger: loggers.Logger = None): """Initializes the learner.""" def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray: # Pull out the data needed for updates. o_tm1, a_tm1, r_t, d_t, o_t = sample.data del r_t, d_t, o_t logits = network.apply(params, o_tm1) return jnp.mean(loss_fn(a_tm1, logits)) def sgd_step( state: TrainingState, sample: reverb.ReplaySample ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]: """Do a step of SGD.""" grad_fn = jax.value_and_grad(loss) loss_value, gradients = grad_fn(state.params, sample) updates, new_opt_state = optimizer.update(gradients, state.opt_state) new_params = optax.apply_updates(state.params, updates) steps = state.steps + 1 new_state = TrainingState( params=new_params, opt_state=new_opt_state, steps=steps) # Compute the global norm of the gradients for logging. global_gradient_norm = optax.global_norm(gradients) fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm} return new_state, fetches self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Get an iterator over the dataset. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types # TODO(b/155086959): Fix type stubs and remove. # Initialise parameters and optimiser state. initial_params = network.init( next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec))) initial_opt_state = optimizer.init(initial_params) self._state = TrainingState( params=initial_params, opt_state=initial_opt_state, steps=0) self._sgd_step = jax.jit(sgd_step)
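# Hypothetical driver sketch (an assumption, not part of the source learner): a
# `step()` method on this class would typically pull one replay sample, apply
# the jitted SGD step defined above, and log the returned fetches.
def step(self):
  sample = next(self._iterator)
  self._state, fetches = self._sgd_step(self._state, sample)
  counts = self._counter.increment(steps=1)
  self._logger.write({**fetches, **counts})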
def ail_update_step(
    state: DiscriminatorTrainingState,
    data: Tuple[types.Transition, types.Transition],
    optimizer: optax.GradientTransformation,
    ail_network: ail_networks.AILNetworks,
    loss_fn: losses.Loss
) -> Tuple[DiscriminatorTrainingState, losses.Metrics]:
  """Runs an update step on the given transitions.

  Args:
    state: The learner state.
    data: Demo and rb transitions.
    optimizer: Discriminator optimizer.
    ail_network: AIL networks.
    loss_fn: Discriminator loss to minimize.

  Returns:
    A new state and metrics.
  """
  demo_transitions, rb_transitions = data
  key, discriminator_key, loss_key = jax.random.split(state.key, 3)

  def compute_loss(
      discriminator_params: networks_lib.Params) -> losses.LossOutput:
    discriminator_fn = functools.partial(
        ail_network.discriminator_network.apply,
        discriminator_params,
        state.policy_params,
        is_training=True,
        rng=discriminator_key)
    return loss_fn(discriminator_fn, state.discriminator_state,
                   demo_transitions, rb_transitions, loss_key)

  loss_grad = jax.grad(compute_loss, has_aux=True)

  grads, (loss, new_discriminator_state) = loss_grad(
      state.discriminator_params)

  update, optimizer_state = optimizer.update(
      grads, state.optimizer_state, params=state.discriminator_params)
  discriminator_params = optax.apply_updates(state.discriminator_params,
                                             update)

  new_state = DiscriminatorTrainingState(
      optimizer_state=optimizer_state,
      discriminator_params=discriminator_params,
      discriminator_state=new_discriminator_state,
      policy_params=state.policy_params,  # Not modified.
      key=key,
      steps=state.steps + 1,
  )
  return new_state, loss
def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation:
  """Adds `weight_decay * parameter` to the updates.

  The decay is only applied to parameters with more than one dimension, i.e.
  layer norm scales, biases, etc. are excluded.

  Args:
    weight_decay: a scalar weight decay rate.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AdditiveWeightDecayState()

  def update_fn(updates, state, params):
    updates = jax.tree_multimap(
        lambda g, p: g + weight_decay * p * (len(g.shape) > 1), updates,
        params)
    return updates, state

  return GradientTransformation(init_fn, update_fn)
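# A minimal usage sketch (an assumption, not from the source): AdamW-style
# decoupled weight decay obtained by chaining the transformation above with Adam
# scaling. Note that `update` must then be called as `tx.update(grads, state,
# params)`, since the decay term needs the current parameters.
import optax

def make_optimizer(learning_rate: float = 1e-3, weight_decay: float = 1e-2):
  return optax.chain(
      optax.scale_by_adam(),
      additive_weight_decay(weight_decay),
      optax.scale(-learning_rate),
  )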
def fit(params: optax.Params, opt: optax.GradientTransformation) -> optax.Params:
  opt_state = opt.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    (loss_val, accuracy), grads = jax.value_and_grad(
        loss, has_aux=True)(params, batch, labels)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val, accuracy

  for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
    params, opt_state, loss_val, accuracy = step(params, opt_state, batch,
                                                 labels)
    if i % 100 == 0:
      print(
          f"step {i}/{nb_steps} | loss: {loss_val:.5f} | "
          f"accuracy: {accuracy*100:.2f}%"
      )
  return params
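# A minimal usage sketch (assumptions: `loss`, `train_data`, `train_labels` and
# `nb_steps` are defined by the surrounding script, and `init_params` is a
# parameter pytree): `fit` only needs initial parameters and an optax optimizer.
import optax

optimizer = optax.adam(learning_rate=1e-3)
trained_params = fit(init_params, optimizer)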
def __init__(self, counter: counting.Counter, direct_rl_learner_factory: Callable[ [Iterator[reverb.ReplaySample]], acme.Learner], loss_fn: losses.Loss, iterator: Iterator[AILSample], discriminator_optimizer: optax.GradientTransformation, ail_network: ail_networks.AILNetworks, discriminator_key: networks_lib.PRNGKey, is_sequence_based: bool, num_sgd_steps_per_step: int = 1, policy_variable_name: Optional[str] = None, logger: Optional[loggers.Logger] = None): """AIL Learner. Args: counter: Counter. direct_rl_learner_factory: Function that creates the direct RL learner when passed a replay sample iterator. loss_fn: Discriminator loss. iterator: Iterator that returns AILSamples. discriminator_optimizer: Discriminator optax optimizer. ail_network: AIL networks. discriminator_key: RNG key. is_sequence_based: If True, a direct rl algorithm is using SequenceAdder data format. Otherwise the learner assumes that the direct rl algorithm is using NStepTransitionAdder. num_sgd_steps_per_step: Number of discriminator gradient updates per step. policy_variable_name: The name of the policy variable to retrieve direct_rl policy parameters. logger: Logger. """ self._is_sequence_based = is_sequence_based state_key, networks_key = jax.random.split(discriminator_key) # Generator expression that works the same as an iterator. # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions iterator, direct_rl_iterator = itertools.tee(iterator) direct_rl_iterator = (self._process_sample(sample.direct_sample) for sample in direct_rl_iterator) self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator) self._iterator = iterator if policy_variable_name is not None: def get_policy_params(): return self._direct_rl_learner.get_variables( [policy_variable_name])[0] self._get_policy_params = get_policy_params else: self._get_policy_params = lambda: None # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray) # Use the JIT compiler. self._update_step = functools.partial( ail_update_step, optimizer=discriminator_optimizer, ail_network=ail_network, loss_fn=loss_fn) self._update_step = utils.process_multiple_batches( self._update_step, num_sgd_steps_per_step) self._update_step = jax.jit(self._update_step) discriminator_params, discriminator_state = ( ail_network.discriminator_network.init(networks_key)) self._state = DiscriminatorTrainingState( optimizer_state=discriminator_optimizer.init(discriminator_params), discriminator_params=discriminator_params, discriminator_state=discriminator_state, policy_params=self._get_policy_params(), key=state_key, steps=0, ) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None self._get_reward = jax.jit( functools.partial(ail_networks.compute_ail_reward, networks=ail_network))
def __init__( self, preprocessor: processors.Processor, sample_network_input: jnp.ndarray, network: parts.Network, optimizer: optax.GradientTransformation, transition_accumulator: Any, replay: replay_lib.TransitionReplay, shaping_function, mask_probability: float, num_heads: int, batch_size: int, exploration_epsilon: Callable[[int], float], min_replay_capacity_fraction: float, learn_period: int, target_network_update_period: int, grad_error_bound: float, rng_key: parts.PRNGKey, ): self._preprocessor = preprocessor self._replay = replay self._transition_accumulator = transition_accumulator self._mask_probabilities = jnp.array( [mask_probability, 1 - mask_probability]) self._num_heads = num_heads self._batch_size = batch_size self._exploration_epsilon = exploration_epsilon self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity self._learn_period = learn_period self._target_network_update_period = target_network_update_period # Initialize network parameters and optimizer. self._rng_key, network_rng_key = jax.random.split(rng_key) self._online_params = network.init(network_rng_key, sample_network_input[None, ...]) self._target_params = self._online_params self._opt_state = optimizer.init(self._online_params) # Other agent state: last action, frame count, etc. self._action = None self._frame_t = -1 # Current frame index. # Define jitted loss, update, and policy functions here instead of as # class methods, to emphasize that these are meant to be pure functions # and should not access the agent object's state via `self`. def loss_fn(online_params, target_params, transitions, rng_key): """Calculates loss given network parameters and transitions.""" _, online_key, target_key, shaping_key = jax.random.split( rng_key, 4) q_tm1 = network.apply(online_params, online_key, transitions.s_tm1).multi_head_output q_target_t = network.apply(target_params, target_key, transitions.s_t).multi_head_output # batch by num_heads -> batch by num_heads by num_actions mask = jnp.einsum('ij,k->ijk', transitions.mask_t, jnp.ones(q_tm1.shape[-1])) masked_q = jnp.multiply(mask, q_tm1) masked_q_target = jnp.multiply(mask, q_target_t) flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1])) flattened_q_target = jnp.reshape(q_target_t, (-1, q_target_t.shape[-1])) # compute shaping function F(s, a, s') shaped_rewards = shaping_function(q_target_t, transitions, shaping_key) repeated_actions = jnp.repeat(transitions.a_tm1, num_heads) repeated_rewards = jnp.repeat(shaped_rewards, num_heads) repeated_discounts = jnp.repeat(transitions.discount_t, num_heads) td_errors = _batch_q_learning( flattened_q, repeated_actions, repeated_rewards, repeated_discounts, flattened_q_target, ) td_errors = rlax.clip_gradient(td_errors, -grad_error_bound, grad_error_bound) losses = rlax.l2_loss(td_errors) assert losses.shape == (self._batch_size * num_heads, ) loss = jnp.mean(losses) return loss def update(rng_key, opt_state, online_params, target_params, transitions): """Computes learning update from batch of replay transitions.""" rng_key, update_key = jax.random.split(rng_key) d_loss_d_params = jax.grad(loss_fn)(online_params, target_params, transitions, update_key) updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state) new_online_params = optax.apply_updates(online_params, updates) return rng_key, new_opt_state, new_online_params self._update = jax.jit(update) def select_action(rng_key, network_params, s_t, exploration_epsilon): """Samples action from eps-greedy policy wrt Q-values at given state.""" rng_key, 
apply_key, policy_key = jax.random.split(rng_key, 3) q_t = network.apply(network_params, apply_key, s_t[None, ...]).random_head_q_value[0] a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon) return rng_key, a_t self._select_action = jax.jit(select_action)
def __init__(self, networks: td3_networks.TD3Networks, random_key: networks_lib.PRNGKey, discount: float, iterator: Iterator[reverb.ReplaySample], policy_optimizer: optax.GradientTransformation, critic_optimizer: optax.GradientTransformation, twin_critic_optimizer: optax.GradientTransformation, delay: int = 2, target_sigma: float = 0.2, noise_clip: float = 0.5, tau: float = 0.005, use_sarsa_target: bool = False, bc_alpha: Optional[float] = None, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, num_sgd_steps_per_step: int = 1): """Initializes the TD3 learner. Args: networks: TD3 networks. random_key: a key for random number generation. discount: discount to use for TD updates iterator: an iterator over training data. policy_optimizer: the policy optimizer. critic_optimizer: the Q-function optimizer. twin_critic_optimizer: the twin Q-function optimizer. delay: ratio of policy updates for critic updates (see TD3), delay=2 means 2 updates of the critic for 1 policy update. target_sigma: std of zero mean Gaussian added to the action of the next_state, for critic evaluation (reducing overestimation bias). noise_clip: hard constraint on target noise. tau: target parameters smoothing coefficient. use_sarsa_target: compute on-policy target using iterator's actions rather than sampled actions. Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). When set to `True`, `target_policy_params` are unused. This is only working when the learner is used as an offline algorithm. I.e. TD3Builder does not support adding the SARSA target to the replay buffer. bc_alpha: bc_alpha: Implements TD3+BC. See comments in TD3Config.bc_alpha for details. counter: counter object used to keep track of steps. logger: logger object to be used by learner. num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'. """ def policy_loss( policy_params: networks_lib.Params, critic_params: networks_lib.Params, transition: types.NestedArray, ) -> jnp.ndarray: # Computes the discrete policy gradient loss. action = networks.policy_network.apply(policy_params, transition.observation) grad_critic = jax.vmap(jax.grad(networks.critic_network.apply, argnums=2), in_axes=(None, 0, 0)) dq_da = grad_critic(critic_params, transition.observation, action) batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0)) loss = jnp.mean(batch_dpg_learning(action, dq_da)) if bc_alpha is not None: # BC regularization for offline RL q_sa = networks.critic_network.apply(critic_params, transition.observation, action) bc_factor = jax.lax.stop_gradient(bc_alpha / jnp.mean(jnp.abs(q_sa))) loss += jnp.mean( jnp.square(action - transition.action)) / bc_factor return loss def critic_loss( critic_params: networks_lib.Params, state: TrainingState, transition: types.Transition, random_key: jnp.ndarray, ): # Computes the critic loss. q_tm1 = networks.critic_network.apply(critic_params, transition.observation, transition.action) if use_sarsa_target: # TODO(b/222674779): use N-steps Trajectories to get the next actions. 
assert 'next_action' in transition.extras, ( 'next actions should be given as extras for one step RL.') action = transition.extras['next_action'] else: action = networks.policy_network.apply( state.target_policy_params, transition.next_observation) action = networks.add_policy_noise(action, random_key, target_sigma, noise_clip) q_t = networks.critic_network.apply(state.target_critic_params, transition.next_observation, action) twin_q_t = networks.twin_critic_network.apply( state.target_twin_critic_params, transition.next_observation, action) q_t = jnp.minimum(q_t, twin_q_t) target_q_tm1 = transition.reward + discount * transition.discount * q_t td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1 return jnp.mean(jnp.square(td_error)) def update_step( state: TrainingState, transitions: types.Transition, ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: random_key, key_critic, key_twin = jax.random.split( state.random_key, 3) # Updates on the critic: compute the gradients, and update using # Polyak averaging. critic_loss_and_grad = jax.value_and_grad(critic_loss) critic_loss_value, critic_gradients = critic_loss_and_grad( state.critic_params, state, transitions, key_critic) critic_updates, critic_opt_state = critic_optimizer.update( critic_gradients, state.critic_opt_state) critic_params = optax.apply_updates(state.critic_params, critic_updates) # In the original authors' implementation the critic target update is # delayed similarly to the policy update which we found empirically to # perform slightly worse. target_critic_params = optax.incremental_update( new_tensors=critic_params, old_tensors=state.target_critic_params, step_size=tau) # Updates on the twin critic: compute the gradients, and update using # Polyak averaging. twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad( state.twin_critic_params, state, transitions, key_twin) twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update( twin_critic_gradients, state.twin_critic_opt_state) twin_critic_params = optax.apply_updates(state.twin_critic_params, twin_critic_updates) # In the original authors' implementation the twin critic target update is # delayed similarly to the policy update which we found empirically to # perform slightly worse. target_twin_critic_params = optax.incremental_update( new_tensors=twin_critic_params, old_tensors=state.target_twin_critic_params, step_size=tau) # Updates on the policy: compute the gradients, and update using # Polyak averaging (if delay enabled, the update might not be applied). policy_loss_and_grad = jax.value_and_grad(policy_loss) policy_loss_value, policy_gradients = policy_loss_and_grad( state.policy_params, state.critic_params, transitions) def update_policy_step(): policy_updates, policy_opt_state = policy_optimizer.update( policy_gradients, state.policy_opt_state) policy_params = optax.apply_updates(state.policy_params, policy_updates) target_policy_params = optax.incremental_update( new_tensors=policy_params, old_tensors=state.target_policy_params, step_size=tau) return policy_params, target_policy_params, policy_opt_state # The update on the policy is applied every `delay` steps. 
current_policy_state = (state.policy_params, state.target_policy_params, state.policy_opt_state) policy_params, target_policy_params, policy_opt_state = jax.lax.cond( state.steps % delay == 0, lambda _: update_policy_step(), lambda _: current_policy_state, operand=None) steps = state.steps + 1 new_state = TrainingState( policy_params=policy_params, critic_params=critic_params, twin_critic_params=twin_critic_params, target_policy_params=target_policy_params, target_critic_params=target_critic_params, target_twin_critic_params=target_twin_critic_params, policy_opt_state=policy_opt_state, critic_opt_state=critic_opt_state, twin_critic_opt_state=twin_critic_opt_state, steps=steps, random_key=random_key, ) metrics = { 'policy_loss': policy_loss_value, 'critic_loss': critic_loss_value, 'twin_critic_loss': twin_critic_loss_value, } return new_state, metrics # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray, steps_key=self._counter.get_steps_key()) # Create prefetching dataset iterator. self._iterator = iterator # Faster sgd step update_step = utils.process_multiple_batches(update_step, num_sgd_steps_per_step) # Use the JIT compiler. self._update_step = jax.jit(update_step) (key_init_policy, key_init_twin, key_init_target, key_state) = jax.random.split(random_key, 4) # Create the network parameters and copy into the target network parameters. initial_policy_params = networks.policy_network.init(key_init_policy) initial_critic_params = networks.critic_network.init(key_init_twin) initial_twin_critic_params = networks.twin_critic_network.init( key_init_target) initial_target_policy_params = initial_policy_params initial_target_critic_params = initial_critic_params initial_target_twin_critic_params = initial_twin_critic_params # Initialize optimizers. initial_policy_opt_state = policy_optimizer.init(initial_policy_params) initial_critic_opt_state = critic_optimizer.init(initial_critic_params) initial_twin_critic_opt_state = twin_critic_optimizer.init( initial_twin_critic_params) # Create initial state. self._state = TrainingState( policy_params=initial_policy_params, target_policy_params=initial_target_policy_params, critic_params=initial_critic_params, twin_critic_params=initial_twin_critic_params, target_critic_params=initial_target_critic_params, target_twin_critic_params=initial_target_twin_critic_params, policy_opt_state=initial_policy_opt_state, critic_opt_state=initial_critic_opt_state, twin_critic_opt_state=initial_twin_critic_opt_state, steps=0, random_key=key_state) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
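# Standalone sketch (toy values, an illustration rather than part of the learner):
# `optax.incremental_update(new, old, step_size=tau)`, used above for the target
# networks, computes `tau * new + (1 - tau) * old` leaf by leaf (Polyak averaging).
import jax.numpy as jnp
import optax

online = {'w': jnp.array([1.0, 2.0])}
target = {'w': jnp.array([0.0, 0.0])}
target = optax.incremental_update(
    new_tensors=online, old_tensors=target, step_size=0.005)
# target['w'] is now [0.005, 0.01].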
def __init__(self,
             observation_spec: specs.Array,
             action_spec: specs.DiscreteArray,
             network: PolicyValueNet,
             initial_rnn_state: RNNState,
             optimizer: optax.GradientTransformation,
             rng: hk.PRNGSequence,
             buffer_length: int,
             discount: float,
             td_lambda: float,
             entropy_cost: float = 1.,
             critic_cost: float = 1.):

  @jax.jit
  def pack(trajectory: buffer.Trajectory) -> List[jnp.ndarray]:
    """Converts a trajectory into an input."""
    observations = trajectory.observations[:, None, ...]
    rewards = jnp.concatenate(
        [trajectory.previous_reward,
         jnp.squeeze(trajectory.rewards, -1)], -1)
    rewards = jnp.squeeze(rewards)
    rewards = jnp.expand_dims(rewards, (1, 2))
    previous_action = jax.nn.one_hot(trajectory.previous_action,
                                     action_spec.num_values)
    actions = jax.nn.one_hot(
        jnp.squeeze(trajectory.actions, 1), action_spec.num_values)
    actions = jnp.expand_dims(
        jnp.concatenate([previous_action, actions], 0), 1)
    return [observations, rewards, actions]

  @jax.jit
  def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
    """Computes a linear combination of the policy gradient loss and value
    loss and regularizes it with an entropy term."""
    inputs = pack(trajectory)
    # Dynamically unroll the network. This Haiku utility function unpacks the
    # list of input tensors such that the i^{th} row from each input tensor
    # is presented to the i^{th} unrolled RNN module.
    (logits, values, _, _,
     state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
         network, inputs, rnn_unroll_state)
    trajectory_len = trajectory.actions.shape[0]

    # Compute the combined loss given the output of the model.
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1, 0],
        r_t=jnp.squeeze(trajectory.rewards, -1),
        discount_t=trajectory.discounts * discount,
        v_t=values[1:, 0],
        lambda_=jnp.array(td_lambda))
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1, 0],
        a_t=jnp.squeeze(trajectory.actions, 1),
        adv_t=td_errors,
        w_t=jnp.ones(trajectory_len))
    entropy_loss = jnp.mean(
        rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

    combined_loss = (actor_loss + critic_cost * critic_loss +
                     entropy_cost * entropy_loss)
    return combined_loss, new_rnn_unroll_state

  # Transform the loss into a pure function.
  loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: AgentState,
               trajectory: buffer.Trajectory) -> AgentState:
    """Performs a step of SGD over a trajectory."""
    gradients, new_rnn_state = jax.grad(
        loss_fn, has_aux=True)(state.params, trajectory,
                               state.rnn_unroll_state)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    return state._replace(
        params=new_params,
        opt_state=new_opt_state,
        rnn_unroll_state=new_rnn_state)

  # Initialize network parameters.
  init, forward = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_observation = jnp.zeros((1, *observation_spec.shape),
                                dtype=observation_spec.dtype)
  dummy_reward = jnp.zeros((1, 1, 1))
  dummy_action = jnp.zeros((1, 1, action_spec.num_values))
  inputs = [dummy_observation, dummy_reward, dummy_action]
  initial_params = init(next(rng), inputs, initial_rnn_state)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = AgentState(initial_params, initial_opt_state,
                           initial_rnn_state, initial_rnn_state)
  self._forward = jax.jit(forward)
  self._buffer = buffer.Buffer(observation_spec, action_spec, buffer_length)
  self._sgd_step = sgd_step
  self._rng = rng
  self._initial_rnn_state = initial_rnn_state
  self._action_spec = action_spec
def __init__(self, networks: CRRNetworks, random_key: networks_lib.PRNGKey, discount: float, target_update_period: int, policy_loss_coeff_fn: PolicyLossCoeff, iterator: Iterator[types.Transition], policy_optimizer: optax.GradientTransformation, critic_optimizer: optax.GradientTransformation, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, grad_updates_per_batch: int = 1, use_sarsa_target: bool = False): """Initializes the CRR learner. Args: networks: CRR networks. random_key: a key for random number generation. discount: discount to use for TD updates. target_update_period: period to update target's parameters. policy_loss_coeff_fn: set the loss function for the policy. iterator: an iterator over training data. policy_optimizer: the policy optimizer. critic_optimizer: the Q-function optimizer. counter: counter object used to keep track of steps. logger: logger object to be used by learner. grad_updates_per_batch: how many gradient updates given a sampled batch. use_sarsa_target: compute on-policy target using iterator's actions rather than sampled actions. Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf). When set to `True`, `target_policy_params` are unused. """ critic_network = networks.critic_network policy_network = networks.policy_network def policy_loss( policy_params: networks_lib.Params, critic_params: networks_lib.Params, transition: types.Transition, key: networks_lib.PRNGKey, ) -> jnp.ndarray: # Compute the loss coefficients. coeff = policy_loss_coeff_fn(networks, policy_params, critic_params, transition, key) coeff = jax.lax.stop_gradient(coeff) # Return the weighted loss. dist_params = policy_network.apply(policy_params, transition.observation) logp_action = networks.log_prob(dist_params, transition.action) return -jnp.mean(logp_action * coeff) def critic_loss( critic_params: networks_lib.Params, target_policy_params: networks_lib.Params, target_critic_params: networks_lib.Params, transition: types.Transition, key: networks_lib.PRNGKey, ): # Sample the next action. if use_sarsa_target: # TODO(b/222674779): use N-steps Trajectories to get the next actions. assert 'next_action' in transition.extras, ( 'next actions should be given as extras for one step RL.') next_action = transition.extras['next_action'] else: next_dist_params = policy_network.apply( target_policy_params, transition.next_observation) next_action = networks.sample(next_dist_params, key) # Calculate the value of the next state and action. next_q = critic_network.apply(target_critic_params, transition.next_observation, next_action) target_q = transition.reward + transition.discount * discount * next_q target_q = jax.lax.stop_gradient(target_q) q = critic_network.apply(critic_params, transition.observation, transition.action) q_error = q - target_q # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. # TODO(sertan): Replace with a distributional critic. CRR paper states # that this may perform better. return 0.5 * jnp.mean(jnp.square(q_error)) policy_loss_and_grad = jax.value_and_grad(policy_loss) critic_loss_and_grad = jax.value_and_grad(critic_loss) def sgd_step( state: TrainingState, transitions: types.Transition, ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: key, key_policy, key_critic = jax.random.split(state.key, 3) # Compute losses and their gradients. 
policy_loss_value, policy_gradients = policy_loss_and_grad( state.policy_params, state.critic_params, transitions, key_policy) critic_loss_value, critic_gradients = critic_loss_and_grad( state.critic_params, state.target_policy_params, state.target_critic_params, transitions, key_critic) # Get optimizer updates and state. policy_updates, policy_opt_state = policy_optimizer.update( policy_gradients, state.policy_opt_state) critic_updates, critic_opt_state = critic_optimizer.update( critic_gradients, state.critic_opt_state) # Apply optimizer updates to parameters. policy_params = optax.apply_updates(state.policy_params, policy_updates) critic_params = optax.apply_updates(state.critic_params, critic_updates) steps = state.steps + 1 # Periodically update target networks. target_policy_params, target_critic_params = optax.periodic_update( (policy_params, critic_params), (state.target_policy_params, state.target_critic_params), steps, target_update_period) new_state = TrainingState( policy_params=policy_params, target_policy_params=target_policy_params, critic_params=critic_params, target_critic_params=target_critic_params, policy_opt_state=policy_opt_state, critic_opt_state=critic_opt_state, steps=steps, key=key, ) metrics = { 'policy_loss': policy_loss_value, 'critic_loss': critic_loss_value, } return new_state, metrics sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch) self._sgd_step = jax.jit(sgd_step) # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray, steps_key=self._counter.get_steps_key()) # Create prefetching dataset iterator. self._iterator = iterator # Create the network parameters and copy into the target network parameters. key, key_policy, key_critic = jax.random.split(random_key, 3) initial_policy_params = policy_network.init(key_policy) initial_critic_params = critic_network.init(key_critic) initial_target_policy_params = initial_policy_params initial_target_critic_params = initial_critic_params # Initialize optimizers. initial_policy_opt_state = policy_optimizer.init(initial_policy_params) initial_critic_opt_state = critic_optimizer.init(initial_critic_params) # Create initial state. self._state = TrainingState( policy_params=initial_policy_params, target_policy_params=initial_target_policy_params, critic_params=initial_critic_params, target_critic_params=initial_target_critic_params, policy_opt_state=initial_policy_opt_state, critic_opt_state=initial_critic_opt_state, steps=0, key=key, ) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
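# Standalone sketch (toy values, an illustration rather than part of the learner):
# `optax.periodic_update(new, old, steps, update_period)`, used above for the
# target networks, returns the online parameters whenever
# `steps % update_period == 0` and the old target parameters otherwise.
import jax.numpy as jnp
import optax

online = {'w': jnp.array([1.0])}
target = {'w': jnp.array([0.0])}
target = optax.periodic_update(online, target, jnp.asarray(4, jnp.int32), 2)
# steps == 4 is a multiple of 2, so `target` now equals `online`.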
def __init__(self,
             batch_size: int,
             networks: CQLNetworks,
             random_key: networks_lib.PRNGKey,
             demonstrations: Iterator[types.Transition],
             policy_optimizer: optax.GradientTransformation,
             critic_optimizer: optax.GradientTransformation,
             tau: float = 0.005,
             fixed_cql_coefficient: Optional[float] = None,
             cql_lagrange_threshold: Optional[float] = None,
             cql_num_samples: int = 10,
             num_sgd_steps_per_step: int = 1,
             reward_scale: float = 1.0,
             discount: float = 0.99,
             fixed_entropy_coefficient: Optional[float] = None,
             target_entropy: Optional[float] = 0,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the CQL learner.

  Args:
    batch_size: batch size.
    networks: CQL networks.
    random_key: a key for random number generation.
    demonstrations: an iterator over training data.
    policy_optimizer: the policy optimizer.
    critic_optimizer: the Q-function optimizer.
    tau: target smoothing coefficient.
    fixed_cql_coefficient: the value for the cql coefficient. If None, an
      adaptive coefficient will be used.
    cql_lagrange_threshold: a threshold that controls the adaptive loss for
      the cql coefficient.
    cql_num_samples: number of samples used to compute logsumexp(Q) via
      importance sampling.
    num_sgd_steps_per_step: how many gradient updates to perform per batch.
      The batch is split into this many smaller batches, so the batch size
      should be a multiple of num_sgd_steps_per_step.
    reward_scale: reward scale.
    discount: discount to use for TD updates.
    fixed_entropy_coefficient: coefficient applied to the entropy bonus. If
      None, an adaptive coefficient will be used.
    target_entropy: target entropy when using an adaptive entropy bonus.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
  """
  adaptive_entropy_coefficient = fixed_entropy_coefficient is None
  action_spec = networks.environment_specs.actions
  if adaptive_entropy_coefficient:
    # sac_alpha is the temperature parameter that determines the relative
    # importance of the entropy term versus the reward.
log_sac_alpha = jnp.asarray(0., dtype=jnp.float32) alpha_optimizer = optax.adam(learning_rate=3e-4) alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha) else: if target_entropy: raise ValueError('target_entropy should not be set when ' 'fixed_entropy_coefficient is provided') adaptive_cql_coefficient = fixed_cql_coefficient is None if adaptive_cql_coefficient: log_cql_alpha = jnp.asarray(0., dtype=jnp.float32) cql_optimizer = optax.adam(learning_rate=3e-4) cql_optimizer_state = cql_optimizer.init(log_cql_alpha) else: if cql_lagrange_threshold: raise ValueError( 'cql_lagrange_threshold should not be set when ' 'fixed_cql_coefficient is provided') def alpha_loss(log_sac_alpha: jnp.ndarray, policy_params: networks_lib.Params, transitions: types.Transition, key: jnp.ndarray) -> jnp.ndarray: """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" dist_params = networks.policy_network.apply( policy_params, transitions.observation) action = networks.sample(dist_params, key) log_prob = networks.log_prob(dist_params, action) sac_alpha = jnp.exp(log_sac_alpha) sac_alpha_loss = sac_alpha * jax.lax.stop_gradient(-log_prob - target_entropy) return jnp.mean(sac_alpha_loss) def sac_critic_loss(q_old_action: jnp.ndarray, policy_params: networks_lib.Params, target_critic_params: networks_lib.Params, sac_alpha: jnp.ndarray, transitions: types.Transition, key: networks_lib.PRNGKey) -> jnp.ndarray: """Computes the SAC part of the loss.""" next_dist_params = networks.policy_network.apply( policy_params, transitions.next_observation) next_action = networks.sample(next_dist_params, key) next_log_prob = networks.log_prob(next_dist_params, next_action) next_q = networks.critic_network.apply( target_critic_params, transitions.next_observation, next_action) next_v = jnp.min(next_q, axis=-1) - sac_alpha * next_log_prob target_q = jax.lax.stop_gradient( transitions.reward * reward_scale + transitions.discount * discount * next_v) return jnp.mean( jnp.square(q_old_action - jnp.expand_dims(target_q, -1))) def batched_critic(actions: jnp.ndarray, critic_params: networks_lib.Params, observation: jnp.ndarray) -> jnp.ndarray: """Applies the critic network to a batch of sampled actions.""" actions = jax.lax.stop_gradient(actions) tiled_actions = jnp.reshape(actions, (batch_size * cql_num_samples, -1)) tiled_states = jnp.tile(observation, [cql_num_samples, 1]) tiled_q = networks.critic_network.apply(critic_params, tiled_states, tiled_actions) return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1)) def cql_critic_loss(q_old_action: jnp.ndarray, critic_params: networks_lib.Params, policy_params: networks_lib.Params, transitions: types.Transition, key: networks_lib.PRNGKey) -> jnp.ndarray: """Computes the CQL part of the loss.""" # The CQL part of the loss is # logsumexp(Q(s,·)) - Q(s,a), # where s is the currrent state, and a the action in the dataset (so # Q(s,a) is simply q_old_action. # We need to estimate logsumexp(Q). This is done with importance sampling # (IS). This function implements the unlabeled equation page 29, Appx. F, # in https://arxiv.org/abs/2006.04779. # Here, IS is done with the uniform distribution and the policy in the # current state s. In their implementation, the authors also add the # policy in the transiting state s': # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py, # (l. 233-236). 
key_policy, key_policy_next, key_uniform = jax.random.split(key, 3) def sampled_q(obs, key): actions, log_probs = apply_and_sample_n( key, networks, policy_params, obs, cql_num_samples) return batched_critic( actions, critic_params, transitions.observation) - jax.lax.stop_gradient( jnp.expand_dims(log_probs, -1)) # Sample wrt policy in s sampled_q_from_policy = sampled_q(transitions.observation, key_policy) # Sample wrt policy in s' sampled_q_from_policy_next = sampled_q( transitions.next_observation, key_policy_next) # Sample wrt uniform actions_uniform = jax.random.uniform( key_uniform, (cql_num_samples, batch_size) + action_spec.shape, minval=action_spec.minimum, maxval=action_spec.maximum) log_prob_uniform = -jnp.sum( jnp.log(action_spec.maximum - action_spec.minimum)) sampled_q_from_uniform = (batched_critic( actions_uniform, critic_params, transitions.observation) - log_prob_uniform) # Combine the samplings combined = jnp.concatenate( (sampled_q_from_uniform, sampled_q_from_policy, sampled_q_from_policy_next), axis=0) lse_q = jax.nn.logsumexp(combined, axis=0, b=1. / (3 * cql_num_samples)) return jnp.mean(lse_q - q_old_action) def critic_loss(critic_params: networks_lib.Params, policy_params: networks_lib.Params, target_critic_params: networks_lib.Params, sac_alpha: jnp.ndarray, cql_alpha: jnp.ndarray, transitions: types.Transition, key: networks_lib.PRNGKey) -> jnp.ndarray: """Computes the full critic loss.""" key_cql, key_sac = jax.random.split(key, 2) q_old_action = networks.critic_network.apply( critic_params, transitions.observation, transitions.action) cql_loss = cql_critic_loss(q_old_action, critic_params, policy_params, transitions, key_cql) sac_loss = sac_critic_loss(q_old_action, policy_params, target_critic_params, sac_alpha, transitions, key_sac) return cql_alpha * cql_loss + sac_loss def cql_lagrange_loss(log_cql_alpha: jnp.ndarray, critic_params: networks_lib.Params, policy_params: networks_lib.Params, transitions: types.Transition, key: jnp.ndarray) -> jnp.ndarray: """Computes the loss that optimizes the cql coefficient.""" cql_alpha = jnp.exp(log_cql_alpha) q_old_action = networks.critic_network.apply( critic_params, transitions.observation, transitions.action) return -cql_alpha * (cql_critic_loss( q_old_action, critic_params, policy_params, transitions, key) - cql_lagrange_threshold) def actor_loss(policy_params: networks_lib.Params, critic_params: networks_lib.Params, sac_alpha: jnp.ndarray, transitions: types.Transition, key: jnp.ndarray) -> jnp.ndarray: """Computes the loss for the policy.""" dist_params = networks.policy_network.apply( policy_params, transitions.observation) action = networks.sample(dist_params, key) log_prob = networks.log_prob(dist_params, action) q_action = networks.critic_network.apply(critic_params, transitions.observation, action) min_q = jnp.min(q_action, axis=-1) return jnp.mean(sac_alpha * log_prob - min_q) alpha_grad = jax.value_and_grad(alpha_loss) cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss) critic_grad = jax.value_and_grad(critic_loss) actor_grad = jax.value_and_grad(actor_loss) def update_step( state: TrainingState, rb_transitions: types.Transition, ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: key, key_alpha, key_critic, key_actor = jax.random.split( state.key, 4) if adaptive_entropy_coefficient: alpha_loss, alpha_grads = alpha_grad(state.log_sac_alpha, state.policy_params, rb_transitions, key_alpha) sac_alpha = jnp.exp(state.log_sac_alpha) else: sac_alpha = fixed_entropy_coefficient if adaptive_cql_coefficient: 
cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad( state.log_cql_alpha, state.critic_params, state.policy_params, rb_transitions, key_critic) cql_lagrange_grads = jnp.clip(cql_lagrange_grads, -_CQL_GRAD_CLIPPING_VALUE, _CQL_GRAD_CLIPPING_VALUE) cql_alpha = jnp.exp(state.log_cql_alpha) cql_alpha = jnp.clip(cql_alpha, a_min=0., a_max=_CQL_COEFFICIENT_MAX_VALUE) else: cql_alpha = fixed_cql_coefficient critic_loss, critic_grads = critic_grad(state.critic_params, state.policy_params, state.target_critic_params, sac_alpha, cql_alpha, rb_transitions, key_critic) actor_loss, actor_grads = actor_grad(state.policy_params, state.critic_params, sac_alpha, rb_transitions, key_actor) # Apply policy gradients actor_update, policy_optimizer_state = policy_optimizer.update( actor_grads, state.policy_optimizer_state) policy_params = optax.apply_updates(state.policy_params, actor_update) # Apply critic gradients critic_update, critic_optimizer_state = critic_optimizer.update( critic_grads, state.critic_optimizer_state) critic_params = optax.apply_updates(state.critic_params, critic_update) new_target_critic_params = jax.tree_multimap( lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params, critic_params) metrics = { 'critic_loss': critic_loss, 'actor_loss': actor_loss, } new_state = TrainingState( policy_optimizer_state=policy_optimizer_state, critic_optimizer_state=critic_optimizer_state, policy_params=policy_params, critic_params=critic_params, target_critic_params=new_target_critic_params, key=key, steps=state.steps + 1, ) if adaptive_entropy_coefficient: # Apply sac_alpha gradients alpha_update, alpha_optimizer_state = alpha_optimizer.update( alpha_grads, state.alpha_optimizer_state) log_sac_alpha = optax.apply_updates(state.log_sac_alpha, alpha_update) metrics.update({ 'alpha_loss': alpha_loss, 'sac_alpha': jnp.exp(log_sac_alpha), }) new_state = new_state._replace( alpha_optimizer_state=alpha_optimizer_state, log_sac_alpha=log_sac_alpha) if adaptive_cql_coefficient: # Apply cql coeff gradients cql_update, cql_optimizer_state = cql_optimizer.update( cql_lagrange_grads, state.cql_optimizer_state) log_cql_alpha = optax.apply_updates(state.log_cql_alpha, cql_update) metrics.update({ 'cql_lagrange_loss': cql_lagrange_loss, 'cql_alpha': jnp.exp(log_cql_alpha), }) new_state = new_state._replace( cql_optimizer_state=cql_optimizer_state, log_cql_alpha=log_cql_alpha) return new_state, metrics # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray) # Iterator on demonstration transitions. self._demonstrations = demonstrations # Use the JIT compiler. self._update_step = utils.process_multiple_batches( update_step, num_sgd_steps_per_step) self._update_step = jax.jit(self._update_step) # Create initial state. 
key_policy, key_q, training_state_key = jax.random.split(random_key, 3) del random_key policy_params = networks.policy_network.init(key_policy) policy_optimizer_state = policy_optimizer.init(policy_params) critic_params = networks.critic_network.init(key_q) critic_optimizer_state = critic_optimizer.init(critic_params) self._state = TrainingState( policy_optimizer_state=policy_optimizer_state, critic_optimizer_state=critic_optimizer_state, policy_params=policy_params, critic_params=critic_params, target_critic_params=critic_params, key=training_state_key, steps=0) if adaptive_entropy_coefficient: self._state = self._state._replace( alpha_optimizer_state=alpha_optimizer_state, log_sac_alpha=log_sac_alpha) if adaptive_cql_coefficient: self._state = self._state._replace( cql_optimizer_state=cql_optimizer_state, log_cql_alpha=log_cql_alpha) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
def __init__(self, network: hk.Transformed, obs_spec: specs.Array, discount: float, importance_sampling_exponent: float, target_update_period: int, iterator: Iterator[reverb.ReplaySample], optimizer: optax.GradientTransformation, rng: hk.PRNGSequence, max_abs_reward: float = 1., huber_loss_parameter: float = 1., replay_client: reverb.Client = None, counter: counting.Counter = None, logger: loggers.Logger = None): """Initializes the learner.""" def loss(params: hk.Params, target_params: hk.Params, sample: reverb.ReplaySample): o_tm1, a_tm1, r_t, d_t, o_t = sample.data keys, probs = sample.info[:2] # Forward pass. q_tm1 = network.apply(params, o_tm1) q_t_value = network.apply(target_params, o_t) q_t_selector = network.apply(params, o_t) # Cast and clip rewards. d_t = (d_t * discount).astype(jnp.float32) r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32) # Compute double Q-learning n-step TD-error. batch_error = jax.vmap(rlax.double_q_learning) td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector) batch_loss = rlax.huber_loss(td_error, huber_loss_parameter) # Importance weighting. importance_weights = (1. / probs).astype(jnp.float32) importance_weights **= importance_sampling_exponent importance_weights /= jnp.max(importance_weights) # Reweight. mean_loss = jnp.mean(importance_weights * batch_loss) # [] priorities = jnp.abs(td_error).astype(jnp.float64) return mean_loss, (keys, priorities) def sgd_step( state: TrainingState, samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]: grad_fn = jax.grad(loss, has_aux=True) gradients, (keys, priorities) = grad_fn(state.params, state.target_params, samples) updates, new_opt_state = optimizer.update(gradients, state.opt_state) new_params = optax.apply_updates(state.params, updates) steps = state.steps + 1 # Periodically update target networks. target_params = rlax.periodic_update( new_params, state.target_params, steps, self._target_update_period) new_state = TrainingState( params=new_params, target_params=target_params, opt_state=new_opt_state, steps=steps) outputs = LearnerOutputs(keys=keys, priorities=priorities) return new_state, outputs def update_priorities(outputs: LearnerOutputs): replay_client.mutate_priorities( table=adders.DEFAULT_PRIORITY_TABLE, updates=dict(zip(outputs.keys, outputs.priorities))) # Internalise agent components (replay buffer, networks, optimizer). self._replay_client = replay_client self._iterator = utils.prefetch(iterator) # Internalise the hyperparameters. self._target_update_period = target_update_period # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Initialise parameters and optimiser state. initial_params = network.init( next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec))) initial_target_params = network.init( next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec))) initial_opt_state = optimizer.init(initial_params) self._state = TrainingState( params=initial_params, target_params=initial_target_params, opt_state=initial_opt_state, steps=0) self._forward = jax.jit(network.apply) self._sgd_step = jax.jit(sgd_step) self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
def train(network_def: nn.Module, optim: optax.GradientTransformation, alpha_optim: optax.GradientTransformation, optimizer_state: jnp.ndarray, alpha_optimizer_state: jnp.ndarray, network_params: flax.core.FrozenDict, target_params: flax.core.FrozenDict, log_alpha: jnp.ndarray, key: jnp.ndarray, states: jnp.ndarray, actions: jnp.ndarray, next_states: jnp.ndarray, rewards: jnp.ndarray, terminals: jnp.ndarray, cumulative_gamma: float, target_entropy: float, reward_scale_factor: float) -> Mapping[str, Any]: """Run the training step. Returns a list of updated values and losses. Args: network_def: The SAC network definition. optim: The SAC optimizer (which also wraps the SAC parameters). alpha_optim: The optimizer for alpha. optimizer_state: The SAC optimizer state. alpha_optimizer_state: The alpha optimizer state. network_params: Parameters for SAC's online network. target_params: The parameters for SAC's target network. log_alpha: Parameters for alpha network. key: An rng key to use for random action selection. states: A batch of states. actions: A batch of actions. next_states: A batch of next states. rewards: A batch of rewards. terminals: A batch of terminals. cumulative_gamma: The discount factor to use. target_entropy: The target entropy for the agent. reward_scale_factor: A factor by which to scale rewards. Returns: A mapping from string keys to values, including updated optimizers and training statistics. """ # Get the models from all the optimizers. frozen_params = network_params # For use in loss_fn without apply gradients batch_size = states.shape[0] actions = jnp.reshape(actions, (batch_size, -1)) # Flatten def loss_fn( params: flax.core.FrozenDict, log_alpha: flax.core.FrozenDict, state: jnp.ndarray, action: jnp.ndarray, reward: jnp.ndarray, next_state: jnp.ndarray, terminal: jnp.ndarray, rng: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Calculates the loss for one transition. Args: params: Parameters for the SAC network. log_alpha: SAC's log_alpha parameter. state: A single state vector. action: A single action vector. reward: A reward scalar. next_state: A next state vector. terminal: A terminal scalar. rng: An RNG key to use for sampling actions. Returns: A tuple containing 1) the combined SAC loss and 2) a mapping containing statistics from the loss step. """ rng1, rng2 = jax.random.split(rng, 2) # J_Q(\theta) from equation (5) in paper. q_value_1, q_value_2 = network_def.apply( params, state, action, method=network_def.critic) q_value_1 = jnp.squeeze(q_value_1) q_value_2 = jnp.squeeze(q_value_2) target_outputs = network_def.apply(target_params, next_state, rng1, True) target_q_value_1, target_q_value_2 = target_outputs.critic target_q_value = jnp.squeeze( jnp.minimum(target_q_value_1, target_q_value_2)) alpha_value = jnp.exp(log_alpha) log_prob = target_outputs.actor.log_probability target = reward_scale_factor * reward + cumulative_gamma * ( target_q_value - alpha_value * log_prob) * (1. - terminal) target = jax.lax.stop_gradient(target) critic_loss_1 = losses.mse_loss(q_value_1, target) critic_loss_2 = losses.mse_loss(q_value_2, target) critic_loss = jnp.mean(critic_loss_1 + critic_loss_2) # J_{\pi}(\phi) from equation (9) in paper. mean_action, sampled_action, action_log_prob = network_def.apply( params, state, rng2, method=network_def.actor) # We use frozen_params so that gradients can flow back to the actor without # being used to update the critic. 
q_value_no_grad_1, q_value_no_grad_2 = network_def.apply( frozen_params, state, sampled_action, method=network_def.critic) no_grad_q_value = jnp.squeeze( jnp.minimum(q_value_no_grad_1, q_value_no_grad_2)) alpha_value = jnp.exp(jax.lax.stop_gradient(log_alpha)) policy_loss = jnp.mean(alpha_value * action_log_prob - no_grad_q_value) # J(\alpha) from equation (18) in paper. entropy_diff = -action_log_prob - target_entropy alpha_loss = jnp.mean(log_alpha * jax.lax.stop_gradient(entropy_diff)) # Giving a smaller weight to the critic empirically gives better results combined_loss = 0.5 * critic_loss + 1.0 * policy_loss + 1.0 * alpha_loss return combined_loss, { 'critic_loss': critic_loss, 'policy_loss': policy_loss, 'alpha_loss': alpha_loss, 'critic_value_1': q_value_1, 'critic_value_2': q_value_2, 'target_value_1': target_q_value_1, 'target_value_2': target_q_value_2, 'mean_action': mean_action } grad_fn = jax.vmap( jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True), in_axes=(None, None, 0, 0, 0, 0, 0, 0)) rng = jnp.stack(jax.random.split(key, num=batch_size)) (_, aux_vars), gradients = grad_fn(network_params, log_alpha, states, actions, rewards, next_states, terminals, rng) # This calculates the mean gradient/aux_vars using the individual # gradients/aux_vars from each item in the batch. gradients = jax.tree_map(functools.partial(jnp.mean, axis=0), gradients) aux_vars = jax.tree_map(functools.partial(jnp.mean, axis=0), aux_vars) network_gradient, alpha_gradient = gradients # Apply gradients to all the optimizers. updates, optimizer_state = optim.update(network_gradient, optimizer_state, params=network_params) network_params = optax.apply_updates(network_params, updates) alpha_updates, alpha_optimizer_state = alpha_optim.update( alpha_gradient, alpha_optimizer_state, params=log_alpha) log_alpha = optax.apply_updates(log_alpha, alpha_updates) # Compile everything in a dict. returns = { 'network_params': network_params, 'log_alpha': log_alpha, 'optimizer_state': optimizer_state, 'alpha_optimizer_state': alpha_optimizer_state, 'Losses/Critic': aux_vars['critic_loss'], 'Losses/Actor': aux_vars['policy_loss'], 'Losses/Alpha': aux_vars['alpha_loss'], 'Values/CriticValues1': jnp.mean(aux_vars['critic_value_1']), 'Values/CriticValues2': jnp.mean(aux_vars['critic_value_2']), 'Values/TargetValues1': jnp.mean(aux_vars['target_value_1']), 'Values/TargetValues2': jnp.mean(aux_vars['target_value_2']), 'Values/Alpha': jnp.exp(log_alpha), } for i, a in enumerate(aux_vars['mean_action']): returns.update({f'Values/MeanActions{i}': a}) return returns
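# Illustrative sketch (not part of the SAC code above): the per-example
# gradient pattern it relies on, i.e. vmap over jax.value_and_grad and then
# average the gradients with a tree map. The quadratic toy_loss and its
# parameters are placeholders.
import functools
import jax
import jax.numpy as jnp

def toy_loss(params, x):
  return jnp.sum((params['w'] * x - 1.0) ** 2)

params = {'w': jnp.ones((3,))}
xs = jnp.arange(12.0).reshape((4, 3))  # a batch of four inputs

grad_fn = jax.vmap(jax.value_and_grad(toy_loss), in_axes=(None, 0))
losses, grads = grad_fn(params, xs)    # per-example losses and gradients
mean_grads = jax.tree_map(functools.partial(jnp.mean, axis=0), grads)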
def __init__(self, network: networks_lib.FeedForwardNetwork, random_key: networks_lib.PRNGKey, loss_fn: losses.Loss, optimizer: optax.GradientTransformation, demonstrations: Iterator[types.Transition], num_sgd_steps_per_step: int, logger: Optional[loggers.Logger] = None, counter: Optional[counting.Counter] = None): def sgd_step( state: TrainingState, transitions: types.Transition, ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: loss_and_grad = jax.value_and_grad(loss_fn, argnums=1) # Compute losses and their gradients. key, key_input = jax.random.split(state.key) loss_value, gradients = loss_and_grad(network.apply, state.policy_params, key_input, transitions) policy_update, optimizer_state = optimizer.update(gradients, state.optimizer_state) policy_params = optax.apply_updates(state.policy_params, policy_update) new_state = TrainingState( optimizer_state=optimizer_state, policy_params=policy_params, key=key, steps=state.steps + 1, ) metrics = { 'loss': loss_value, 'gradient_norm': optax.global_norm(gradients) } return new_state, metrics # General learner book-keeping and loggers. self._counter = counter or counting.Counter(prefix='learner') self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray) # Iterator on demonstration transitions. self._demonstrations = demonstrations # Split the input batch to `num_sgd_steps_per_step` minibatches in order # to achieve better performance on accelerators. self._sgd_step = jax.jit(utils.process_multiple_batches( sgd_step, num_sgd_steps_per_step)) random_key, init_key = jax.random.split(random_key) policy_params = network.init(init_key) optimizer_state = optimizer.init(policy_params) # Create initial state. self._state = TrainingState( optimizer_state=optimizer_state, policy_params=policy_params, key=random_key, steps=0, ) self._timestamp = None
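# Illustrative sketch (assumed toy loss and data, not the learner's network):
# the core optax update performed inside sgd_step above, including the
# gradient-norm metric it reports.
import jax
import jax.numpy as jnp
import optax

def loss(params, x, y):
  return jnp.mean((x @ params - y) ** 2)

params = jnp.zeros((3,))
x, y = jnp.ones((8, 3)), jnp.ones((8,))
optimizer = optax.adam(1e-3)
optimizer_state = optimizer.init(params)

loss_value, gradients = jax.value_and_grad(loss)(params, x, y)
updates, optimizer_state = optimizer.update(gradients, optimizer_state)
params = optax.apply_updates(params, updates)
metrics = {'loss': loss_value, 'gradient_norm': optax.global_norm(gradients)}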
def __init__( self, obs_spec: specs.Array, action_spec: specs.DiscreteArray, network: Callable[[jnp.ndarray], jnp.ndarray], optimizer: optax.GradientTransformation, batch_size: int, epsilon: float, rng: hk.PRNGSequence, discount: float, replay_capacity: int, min_replay_size: int, sgd_period: int, target_update_period: int, ): # Transform the (impure) network into a pure function. network = hk.without_apply_rng(hk.transform(network, apply_rng=True)) # Define loss function. def loss(params: hk.Params, target_params: hk.Params, transitions: Sequence[jnp.ndarray]) -> jnp.ndarray: """Computes the standard TD(0) Q-learning loss on batch of transitions.""" o_tm1, a_tm1, r_t, d_t, o_t = transitions q_tm1 = network.apply(params, o_tm1) q_t = network.apply(target_params, o_t) batch_q_learning = jax.vmap(rlax.q_learning) td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t) return jnp.mean(td_error**2) # Define update function. @jax.jit def sgd_step(state: TrainingState, transitions: Sequence[jnp.ndarray]) -> TrainingState: """Performs an SGD step on a batch of transitions.""" gradients = jax.grad(loss)(state.params, state.target_params, transitions) updates, new_opt_state = optimizer.update(gradients, state.opt_state) new_params = optax.apply_updates(state.params, updates) return TrainingState(params=new_params, target_params=state.target_params, opt_state=new_opt_state, step=state.step + 1) # Initialize the networks and optimizer. dummy_observation = np.zeros((1, *obs_spec.shape), jnp.float32) initial_params = network.init(next(rng), dummy_observation) initial_target_params = network.init(next(rng), dummy_observation) initial_opt_state = optimizer.init(initial_params) # This carries the agent state relevant to training. self._state = TrainingState(params=initial_params, target_params=initial_target_params, opt_state=initial_opt_state, step=0) self._sgd_step = sgd_step self._forward = jax.jit(network.apply) self._replay = replay.Replay(capacity=replay_capacity) # Store hyperparameters. self._num_actions = action_spec.num_values self._batch_size = batch_size self._sgd_period = sgd_period self._target_update_period = target_update_period self._epsilon = epsilon self._total_steps = 0 self._min_replay_size = min_replay_size
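# Illustrative sketch (toy Q-value tables rather than network outputs): the
# batched TD(0) Q-learning error behind the loss above.
import jax
import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([[1.0, 2.0], [0.5, 0.0]])
q_t = jnp.array([[1.5, 0.5], [1.0, 2.0]])
a_tm1 = jnp.array([0, 1])
r_t = jnp.array([1.0, 0.0])
d_t = jnp.array([0.99, 0.0])  # discount is zero at episode end

td_error = jax.vmap(rlax.q_learning)(q_tm1, a_tm1, r_t, d_t, q_t)
loss = jnp.mean(td_error ** 2)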
def __init__(self, network: networks_lib.FeedForwardNetwork, loss_fn: LossFn, optimizer: optax.GradientTransformation, data_iterator: Iterator[reverb.ReplaySample], target_update_period: int, random_key: networks_lib.PRNGKey, replay_client: Optional[reverb.Client] = None, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, num_sgd_steps_per_step: int = 1): """Initialize the SGD learner.""" self.network = network # Internalize the loss_fn with network. self._loss = jax.jit(functools.partial(loss_fn, self.network)) # SGD performs the loss, optimizer update and periodic target net update. def sgd_step( state: TrainingState, batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]: next_rng_key, rng_key = jax.random.split(state.rng_key) # Implements one SGD step of the loss and updates training state (loss, extra), grads = jax.value_and_grad( self._loss, has_aux=True)(state.params, state.target_params, batch, rng_key) extra.metrics.update({'total_loss': loss}) # Apply the optimizer updates updates, new_opt_state = optimizer.update(grads, state.opt_state) new_params = optax.apply_updates(state.params, updates) # Periodically update target networks. steps = state.steps + 1 target_params = rlax.periodic_update(new_params, state.target_params, steps, target_update_period) new_training_state = TrainingState(new_params, target_params, new_opt_state, steps, next_rng_key) return new_training_state, extra def postprocess_aux(extra: LossExtra) -> LossExtra: reverb_update = jax.tree_map( lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_update) return extra._replace(metrics=jax.tree_map(jnp.mean, extra.metrics), reverb_update=reverb_update) self._num_sgd_steps_per_step = num_sgd_steps_per_step sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step, postprocess_aux) self._sgd_step = jax.jit(sgd_step) # Internalise agent components self._data_iterator = utils.prefetch(data_iterator) self._target_update_period = target_update_period self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None # Initialize the network parameters key_params, key_target, key_state = jax.random.split(random_key, 3) initial_params = self.network.init(key_params) initial_target_params = self.network.init(key_target) self._state = TrainingState( params=initial_params, target_params=initial_target_params, opt_state=optimizer.init(initial_params), steps=0, rng_key=key_state, ) # Update replay priorities def update_priorities(reverb_update: ReverbUpdate) -> None: if replay_client is None: return keys, priorities = tree.map_structure( utils.fetch_devicearray, (reverb_update.keys, reverb_update.priorities)) replay_client.mutate_priorities(table=replay_table_name, updates=dict(zip(keys, priorities))) self._replay_client = replay_client self._async_priority_updater = async_utils.AsyncExecutor( update_priorities)
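# Illustrative sketch (toy parameters; the update period of 4 is arbitrary):
# rlax.periodic_update, as used in sgd_step above, copies the online
# parameters into the target parameters only when `steps` hits a multiple of
# the period and leaves them unchanged otherwise.
import jax.numpy as jnp
import rlax

online_params = {'w': jnp.array([1.0, 2.0])}
target_params = {'w': jnp.zeros((2,))}

for steps in (1, 2, 3, 4):
  target_params = rlax.periodic_update(
      online_params, target_params, jnp.asarray(steps), 4)
# After step 4 the target parameters match the online parameters.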
def __init__(self, observation_spec: specs.Array, action_spec: specs.DiscreteArray, network: PolicyValueNet, optimizer: optax.GradientTransformation, rng: hk.PRNGSequence, buffer_length: int, discount: float, td_lambda: float, entropy_cost: float = 1., critic_cost: float = 1.): @jax.jit def pack(trajectory: buffer.Trajectory) -> List[jnp.ndarray]: """Converts a trajectory into an input.""" observations = trajectory.observations[:, None, ...] rewards = jnp.concatenate([ trajectory.previous_reward, jnp.squeeze(trajectory.rewards, -1) ], -1) rewards = jnp.expand_dims(rewards, (1, 2)) previous_action = jax.nn.one_hot(trajectory.previous_action, action_spec.num_values) actions = jax.nn.one_hot(jnp.squeeze(trajectory.actions, 1), action_spec.num_values) actions = jnp.expand_dims( jnp.concatenate([previous_action, actions], 0), 1) return observations, rewards, actions @jax.jit def loss(trajectory: buffer.Trajectory) -> jnp.ndarray: """Actor-critic loss.""" observations, rewards, actions = pack(trajectory) logits, values, _, _, _ = network(observations, rewards, actions) td_errors = rlax.td_lambda(v_tm1=values[:-1], r_t=jnp.squeeze(trajectory.rewards, -1), discount_t=trajectory.discounts * discount, v_t=values[1:], lambda_=jnp.array(td_lambda)) critic_loss = jnp.mean(td_errors**2) actor_loss = rlax.policy_gradient_loss( logits_t=logits[:-1], a_t=jnp.squeeze(trajectory.actions, 1), adv_t=td_errors, w_t=jnp.ones_like(td_errors)) entropy_loss = jnp.mean( rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors))) return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss # Transform the loss into a pure function. loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply # Define update function. @jax.jit def sgd_step(state: AgentState, trajectory: buffer.Trajectory) -> AgentState: """Performs a step of SGD over a trajectory.""" gradients = jax.grad(loss_fn)(state.params, trajectory) updates, new_opt_state = optimizer.update(gradients, state.opt_state) new_params = optax.apply_updates(state.params, updates) return AgentState(params=new_params, opt_state=new_opt_state) # Initialize network parameters and optimiser state. init, forward = hk.without_apply_rng( hk.transform(network, apply_rng=True)) dummy_observation = jnp.zeros((1, *observation_spec.shape), dtype=observation_spec.dtype) dummy_reward = jnp.zeros((1, 1, 1)) dummy_action = jnp.zeros((1, 1, action_spec.num_values)) initial_params = init(next(rng), dummy_observation, dummy_reward, dummy_action) initial_opt_state = optimizer.init(initial_params) # Internalize state. self._state = AgentState(initial_params, initial_opt_state) self._forward = jax.jit(forward) self._buffer = buffer.Buffer(observation_spec, action_spec, buffer_length) self._sgd_step = sgd_step self._rng = rng self._action_spec = action_spec
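# Illustrative sketch (toy trajectory of length T with A actions; the loss
# coefficients are placeholders): the entropy-regularised actor-critic loss
# assembled above.
import jax.numpy as jnp
import rlax

T, A = 5, 3
logits = jnp.zeros((T + 1, A))
values = jnp.linspace(0.0, 1.0, T + 1)
rewards = jnp.ones((T,))
discounts = 0.99 * jnp.ones((T,))
actions = jnp.zeros((T,), dtype=jnp.int32)

td_errors = rlax.td_lambda(
    v_tm1=values[:-1], r_t=rewards, discount_t=discounts,
    v_t=values[1:], lambda_=jnp.array(0.9))
critic_loss = jnp.mean(td_errors ** 2)
actor_loss = rlax.policy_gradient_loss(
    logits_t=logits[:-1], a_t=actions, adv_t=td_errors,
    w_t=jnp.ones_like(td_errors))
entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors))
total_loss = actor_loss + 1.0 * critic_loss + 1.0 * entropy_loss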
def __init__( self, preprocessor: processors.Processor, sample_network_input: jnp.ndarray, network: parts.Network, support: jnp.ndarray, optimizer: optax.GradientTransformation, transition_accumulator: Any, replay: replay_lib.TransitionReplay, batch_size: int, exploration_epsilon: Callable[[int], float], min_replay_capacity_fraction: float, learn_period: int, target_network_update_period: int, rng_key: parts.PRNGKey, ): self._preprocessor = preprocessor self._replay = replay self._transition_accumulator = transition_accumulator self._batch_size = batch_size self._exploration_epsilon = exploration_epsilon self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity self._learn_period = learn_period self._target_network_update_period = target_network_update_period # Initialize network parameters and optimizer. self._rng_key, network_rng_key = jax.random.split(rng_key) self._online_params = network.init(network_rng_key, sample_network_input[None, ...]) self._target_params = self._online_params self._opt_state = optimizer.init(self._online_params) # Other agent state: last action, frame count, etc. self._action = None self._frame_t = -1 # Current frame index. self._statistics = {'state_value': np.nan} # Define jitted loss, update, and policy functions here instead of as # class methods, to emphasize that these are meant to be pure functions # and should not access the agent object's state via `self`. def loss_fn(online_params, target_params, transitions, rng_key): """Calculates loss given network parameters and transitions.""" _, online_key, target_key = jax.random.split(rng_key, 3) logits_q_tm1 = network.apply(online_params, online_key, transitions.s_tm1).q_logits logits_target_q_t = network.apply(target_params, target_key, transitions.s_t).q_logits losses = _batch_categorical_q_learning( support, logits_q_tm1, transitions.a_tm1, transitions.r_t, transitions.discount_t, support, logits_target_q_t, ) chex.assert_shape(losses, (self._batch_size, )) loss = jnp.mean(losses) return loss def update(rng_key, opt_state, online_params, target_params, transitions): """Computes learning update from batch of replay transitions.""" rng_key, update_key = jax.random.split(rng_key) d_loss_d_params = jax.grad(loss_fn)(online_params, target_params, transitions, update_key) updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state) new_online_params = optax.apply_updates(online_params, updates) return rng_key, new_opt_state, new_online_params self._update = jax.jit(update) def select_action(rng_key, network_params, s_t, exploration_epsilon): """Samples action from eps-greedy policy wrt Q-values at given state.""" rng_key, apply_key, policy_key = jax.random.split(rng_key, 3) q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0] a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon) v_t = jnp.max(q_t, axis=-1) return rng_key, a_t, v_t self._select_action = jax.jit(select_action)
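# Illustrative sketch (toy Q-values; epsilon of 0.05 is arbitrary): the
# epsilon-greedy action selection used in select_action above.
import jax
import jax.numpy as jnp
import rlax

policy_key = jax.random.PRNGKey(0)
q_t = jnp.array([0.1, 0.5, 0.2])
a_t = rlax.epsilon_greedy().sample(policy_key, q_t, 0.05)
v_t = jnp.max(q_t, axis=-1)  # greedy state value, kept for statistics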
def solve_dual_train( env: Dict[int, DualOp], dual_state: ConfigDict, opt: optax.GradientTransformation, inner_opt: InnerMaxStrategy, dual_params: Params, spec_type: verify_utils.SpecType, dual_params_types: ParamsTypes, logger: Callable[[int, Mapping[str, Any]], None], key: jnp.array, num_steps: int, affine_before_relu: bool, device_type=None, merge_problems: Optional[Dict[int, int]] = None, block_to_time: bool = False, ) -> ConfigDict: """Compute verified upper bound via functional lagrangian relaxation. Args: env: Lagrangian computations for each contributing graph node. dual_state: state of the dual problem. opt: an optimizer for the outer Lagrangian parameters. inner_opt: inner optimization strategy for training. dual_params: dual parameters to be minimized via gradient-based optimization. spec_type: Specification type, adversarial or uncertainty specification. dual_params_types: types of inequality encoded by the corresponding dual_params. logger: logging function. key: jax.random.PRNGKey. num_steps: total number of outer optimization steps. affine_before_relu: whether layer ordering uses the affine layer before the ReLU. device_type: string, used to clamp to a particular hardware device. Default None uses JAX default device placement. merge_problems: the key of the dictionary corresponds to the index of the layer to begin the merge, and the associated value corresponds to the number of consecutive layers to be merged with it. For example, `{0: 2, 2: 3}` will merge together layer 0 and 1, as well as layers 2, 3 and 4. block_to_time: whether to block computations at the end of each iteration to account for asynchronicity dispatch when timing. Returns: dual_state: new state of the dual problem. info: various information for logging / debugging. """ assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type' # create dual functions loss_func = dual_build.build_dual_fun( env=env, lagrangian_form=dual_params_types.lagrangian_form, inner_opt=inner_opt, merge_problems=merge_problems, affine_before_relu=affine_before_relu, spec_type=spec_type) value_and_grad = jax.value_and_grad(loss_func, has_aux=True) def grad_step(params, opt_state, key, step): (loss_val, stats), g = value_and_grad(params, key, step) updates, new_opt_state = opt.update(g, opt_state) new_params = optax.apply_updates(params, updates) return new_params, new_opt_state, loss_val, stats # Some solvers (e.g. 
MIP) cannot be jitted and run on CPU only if inner_opt.jittable: grad_step = jax.jit(grad_step, backend=device_type) dual_state.step = 0 dual_state.key = key dual_state.opt_state = opt.init(dual_params) dual_state.dual_params = dual_params dual_state.loss = 0.0 dual_state.best_loss = jnp.inf dual_state.best_dual_params = dual_params # optimize the dual (Lagrange) parameters with a gradient-based optimizer while dual_state.step < num_steps: key_step, dual_state.key = jax.random.split(dual_state.key) start_time = time.time() dual_params, dual_state.opt_state, dual_state.loss, stats = grad_step( dual_state.dual_params, dual_state.opt_state, key_step, dual_state.step) dual_params = dual_build.project_dual(dual_params, dual_params_types) if dual_state.loss <= dual_state.best_loss: dual_state.best_loss = dual_state.loss # store value from previous iteration as loss corresponds to those params dual_state.best_dual_params = dual_state.dual_params dual_state.dual_params = dual_params # projected dual params if block_to_time: dual_state.loss.block_until_ready() # asynchronous dispatch stats['time_per_iteration'] = time.time() - start_time stats['best_loss'] = dual_state.best_loss stats['dual_params_norm'] = optax.global_norm(dual_state.dual_params) logger(dual_state.step, stats) dual_state.step += 1 return dual_state
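# Illustrative sketch (toy loss and data; optax.sgd stands in for the dual
# optimizer, and the step is jitted unconditionally here): the grad_step
# pattern above, where value_and_grad with has_aux returns both the loss and
# auxiliary statistics for logging.
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x):
  loss_val = jnp.sum((params - x) ** 2)
  stats = {'params_norm': optax.global_norm(params)}
  return loss_val, stats

opt = optax.sgd(1e-1)
params = jnp.ones((3,))
opt_state = opt.init(params)

@jax.jit
def grad_step(params, opt_state, x):
  (loss_val, stats), g = jax.value_and_grad(loss_fn, has_aux=True)(params, x)
  updates, new_opt_state = opt.update(g, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, loss_val, stats

for _ in range(5):
  params, opt_state, loss_val, stats = grad_step(params, opt_state, jnp.zeros((3,)))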
def __init__( self, obs_spec: specs.Array, action_spec: specs.DiscreteArray, network: Callable[[jnp.ndarray], jnp.ndarray], num_ensemble: int, batch_size: int, discount: float, replay_capacity: int, min_replay_size: int, sgd_period: int, target_update_period: int, optimizer: optax.GradientTransformation, mask_prob: float, noise_scale: float, epsilon_fn: Callable[[int], float] = lambda _: 0., seed: int = 1, ): # Transform the (impure) network into a pure function. network = hk.without_apply_rng(hk.transform(network, apply_rng=True)) # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`. def loss(params: hk.Params, target_params: hk.Params, transitions: Sequence[jnp.ndarray]) -> jnp.ndarray: """Q-learning loss with added reward noise + half-in bootstrap.""" o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions q_tm1 = network.apply(params, o_tm1) q_t = network.apply(target_params, o_t) r_t += noise_scale * z_t batch_q_learning = jax.vmap(rlax.q_learning) td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t) return jnp.mean(m_t * td_error**2) # Define update function for each member of ensemble.. @jax.jit def sgd_step(state: TrainingState, transitions: Sequence[jnp.ndarray]) -> TrainingState: """Does a step of SGD for the whole ensemble over `transitions`.""" gradients = jax.grad(loss)(state.params, state.target_params, transitions) updates, new_opt_state = optimizer.update(gradients, state.opt_state) new_params = optax.apply_updates(state.params, updates) return TrainingState(params=new_params, target_params=state.target_params, opt_state=new_opt_state, step=state.step + 1) # Initialize parameters and optimizer state for an ensemble of Q-networks. rng = hk.PRNGSequence(seed) dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32) initial_params = [ network.init(next(rng), dummy_obs) for _ in range(num_ensemble) ] initial_target_params = [ network.init(next(rng), dummy_obs) for _ in range(num_ensemble) ] initial_opt_state = [optimizer.init(p) for p in initial_params] # Internalize state. self._ensemble = [ TrainingState(p, tp, o, step=0) for p, tp, o in zip( initial_params, initial_target_params, initial_opt_state) ] self._forward = jax.jit(network.apply) self._sgd_step = sgd_step self._num_ensemble = num_ensemble self._optimizer = optimizer self._replay = replay.Replay(capacity=replay_capacity) # Agent hyperparameters. self._num_actions = action_spec.num_values self._batch_size = batch_size self._sgd_period = sgd_period self._target_update_period = target_update_period self._min_replay_size = min_replay_size self._epsilon_fn = epsilon_fn self._mask_prob = mask_prob # Agent state. self._active_head = self._ensemble[0] self._total_steps = 0
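# Illustrative sketch (placeholder TD errors; the mask probability and noise
# scale are arbitrary): the bootstrap mask m_t and reward noise z_t that the
# ensemble loss above combines with the squared TD error.
import jax
import jax.numpy as jnp

key_m, key_z = jax.random.split(jax.random.PRNGKey(0))
batch = 4
m_t = jax.random.bernoulli(key_m, p=0.5, shape=(batch,)).astype(jnp.float32)
z_t = jax.random.normal(key_z, (batch,))

r_t = jnp.ones((batch,)) + 0.1 * z_t   # rewards perturbed by scaled noise
td_error = r_t - 0.5                   # placeholder TD errors
loss = jnp.mean(m_t * td_error ** 2)   # "half-in" bootstrapped loss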
def __init__(self, unroll: networks_lib.FeedForwardNetwork, initial_state: networks_lib.FeedForwardNetwork, batch_size: int, random_key: networks_lib.PRNGKey, burn_in_length: int, discount: float, importance_sampling_exponent: float, max_priority_weight: float, target_update_period: int, iterator: Iterator[reverb.ReplaySample], optimizer: optax.GradientTransformation, bootstrap_n: int = 5, tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR, clip_rewards: bool = False, max_abs_reward: float = 1., use_core_state: bool = True, prefetch_size: int = 2, replay_client: Optional[reverb.Client] = None, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None): """Initializes the learner.""" random_key, key_initial_1, key_initial_2 = jax.random.split( random_key, 3) initial_state_params = initial_state.init(key_initial_1, batch_size) initial_state = initial_state.apply(initial_state_params, key_initial_2, batch_size) def loss( params: networks_lib.Params, target_params: networks_lib.Params, key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: """Computes mean transformed N-step loss for a batch of sequences.""" # Convert sample data to sequence-major format [T, B, ...]. data = utils.batch_to_sequence(sample.data) # Get core state & warm it up on observations for a burn-in period. if use_core_state: # Replay core state. online_state = jax.tree_map(lambda x: x[0], data.extras['core_state']) else: online_state = initial_state target_state = online_state # Maybe burn the core state in. if burn_in_length: burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation) key_grad, key1, key2 = jax.random.split(key_grad, 3) _, online_state = unroll.apply(params, key1, burn_obs, online_state) _, target_state = unroll.apply(target_params, key2, burn_obs, target_state) # Only get data to learn on from after the end of the burn in period. data = jax.tree_map(lambda seq: seq[burn_in_length:], data) # Unroll on sequences to get online and target Q-Values. key1, key2 = jax.random.split(key_grad) online_q, _ = unroll.apply(params, key1, data.observation, online_state) target_q, _ = unroll.apply(target_params, key2, data.observation, target_state) # Get value-selector actions from online Q-values for double Q-learning. selector_actions = jnp.argmax(online_q, axis=-1) # Preprocess discounts & rewards. discounts = (data.discount * discount).astype(online_q.dtype) rewards = data.reward if clip_rewards: rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward) rewards = rewards.astype(online_q.dtype) # Get N-step transformed TD error and loss. batch_td_error_fn = jax.vmap(functools.partial( rlax.transformed_n_step_q_learning, n=bootstrap_n, tx_pair=tx_pair), in_axes=1, out_axes=1) # TODO(b/183945808): when this bug is fixed, truncations of actions, # rewards, and discounts will no longer be necessary. batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1], target_q[1:], selector_actions[1:], rewards[:-1], discounts[:-1]) batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0) # Importance weighting. probs = sample.info.probability importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype) importance_weights **= importance_sampling_exponent importance_weights /= jnp.max(importance_weights) mean_loss = jnp.mean(importance_weights * batch_loss) # Calculate priorities as a mixture of max and mean sequence errors. 
abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype) max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0) mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0) priorities = (max_priority + mean_priority) return mean_loss, priorities def sgd_step( state: TrainingState, samples: reverb.ReplaySample ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]: """Performs an update step, averaging over pmap replicas.""" # Compute loss and gradients. grad_fn = jax.value_and_grad(loss, has_aux=True) key, key_grad = jax.random.split(state.random_key) (loss_value, priorities), gradients = grad_fn(state.params, state.target_params, key_grad, samples) # Average gradients over pmap replicas before optimizer update. gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME) # Apply optimizer updates. updates, new_opt_state = optimizer.update(gradients, state.opt_state) new_params = optax.apply_updates(state.params, updates) # Periodically update target networks. steps = state.steps + 1 target_params = rlax.periodic_update(new_params, state.target_params, steps, self._target_update_period) new_state = TrainingState(params=new_params, target_params=target_params, opt_state=new_opt_state, steps=steps, random_key=key) return new_state, priorities, {'loss': loss_value} def update_priorities(keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]): keys, priorities = keys_and_priorities keys, priorities = tree.map_structure( # Fetch array and combine device and batch dimensions. lambda x: utils.fetch_devicearray(x).reshape( (-1, ) + x.shape[2:]), (keys, priorities)) replay_client.mutate_priorities( # pytype: disable=attribute-error table=adders.DEFAULT_PRIORITY_TABLE, updates=dict(zip(keys, priorities))) # Internalise components, hyperparameters, logger, counter, and methods. self._replay_client = replay_client self._target_update_period = target_update_period self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray, time_delta=1.) self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) self._async_priority_updater = async_utils.AsyncExecutor( update_priorities) # Initialise and internalise training state (parameters/optimiser state). random_key, key_init = jax.random.split(random_key) initial_params = unroll.init(key_init, initial_state) opt_state = optimizer.init(initial_params) state = TrainingState(params=initial_params, target_params=initial_params, opt_state=opt_state, steps=jnp.array(0), random_key=random_key) # Replicate parameters. self._state = utils.replicate_in_all_devices(state) # Shard multiple inputs with on-device prefetching. # We split samples in two outputs, the keys which need to be kept on-host # since int64 arrays are not supported in TPUs, and the entire sample # separately so it can be sent to the sgd_step method. def split_sample( sample: reverb.ReplaySample) -> utils.PrefetchingSplit: return utils.PrefetchingSplit(host=sample.info.key, device=sample) self._prefetched_iterator = utils.sharded_prefetch( iterator, buffer_size=prefetch_size, num_threads=jax.local_device_count(), split_fn=split_sample)
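# Illustrative sketch (hand-written TD errors of shape [T, B]): the priority
# mixture above, a convex combination of the max and mean absolute TD error
# over the time axis of each sequence.
import jax.numpy as jnp

batch_td_error = jnp.array([[0.5, 0.1],
                            [-1.0, 0.2],
                            [2.0, -0.3]])  # [T=3, B=2]
max_priority_weight = 0.9
abs_td_error = jnp.abs(batch_td_error)
priorities = (max_priority_weight * jnp.max(abs_td_error, axis=0) +
              (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0))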
def __init__(self, network: networks_lib.FeedForwardNetwork, random_key: networks_lib.PRNGKey, loss_fn: losses.Loss, optimizer: optax.GradientTransformation, prefetching_iterator: Iterator[types.Transition], num_sgd_steps_per_step: int, loss_has_aux: bool = False, logger: Optional[loggers.Logger] = None, counter: Optional[counting.Counter] = None): """Behavior Cloning Learner. Args: network: Networks with signature for apply: (params, obs, is_training, key) -> jnp.ndarray and for init: (rng, is_training) -> params random_key: RNG key. loss_fn: BC loss to use. optimizer: Optax optimizer. prefetching_iterator: A sharded prefetching iterator as outputted from `acme.jax.utils.sharded_prefetch`. Please see the documentation for `sharded_prefetch` for more details. num_sgd_steps_per_step: Number of gradient updates per step. loss_has_aux: Whether the loss function returns auxiliary metrics as a second argument. logger: Logger. counter: Counter. """ def sgd_step( state: TrainingState, transitions: types.Transition, ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: loss_and_grad = jax.value_and_grad(loss_fn, argnums=1, has_aux=loss_has_aux) # Compute losses and their gradients. key, key_input = jax.random.split(state.key) loss_result, gradients = loss_and_grad(network.apply, state.policy_params, key_input, transitions) # Combine the gradient across all devices (by taking their mean). gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME) # Compute and combine metrics across all devices. metrics = _create_loss_metrics(loss_has_aux, loss_result, gradients) metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME) policy_update, optimizer_state = optimizer.update( gradients, state.optimizer_state, state.policy_params) policy_params = optax.apply_updates(state.policy_params, policy_update) new_state = TrainingState( optimizer_state=optimizer_state, policy_params=policy_params, key=key, steps=state.steps + 1, ) return new_state, metrics # General learner book-keeping and loggers. self._counter = counter or counting.Counter(prefix='learner') self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray, steps_key=self._counter.get_steps_key()) # Split the input batch to `num_sgd_steps_per_step` minibatches in order # to achieve better performance on accelerators. sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step) self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME) random_key, init_key = jax.random.split(random_key) policy_params = network.init(init_key) optimizer_state = optimizer.init(policy_params) # Create initial state. state = TrainingState( optimizer_state=optimizer_state, policy_params=policy_params, key=random_key, steps=0, ) self._state = utils.replicate_in_all_devices(state) self._timestamp = None self._prefetching_iterator = prefetching_iterator
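# Illustrative sketch (a stand-in local gradient; the axis name is chosen
# here): the cross-device averaging in sgd_step above, where jax.lax.pmean
# inside a pmapped function averages values over the named device axis.
import jax
import jax.numpy as jnp

def per_device_grad(x):
  local_grad = 2.0 * x  # stand-in for a locally computed gradient
  return jax.lax.pmean(local_grad, axis_name='devices')

num_devices = jax.local_device_count()
xs = jnp.ones((num_devices,))
averaged = jax.pmap(per_device_grad, axis_name='devices')(xs)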
def __init__(self, networks: value_dice_networks.ValueDiceNetworks, policy_optimizer: optax.GradientTransformation, nu_optimizer: optax.GradientTransformation, discount: float, rng: jnp.ndarray, iterator_replay: Iterator[reverb.ReplaySample], iterator_demonstrations: Iterator[types.Transition], alpha: float = 0.05, policy_reg_scale: float = 1e-4, nu_reg_scale: float = 10.0, num_sgd_steps_per_step: int = 1, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None): rng, policy_key, nu_key = jax.random.split(rng, 3) policy_init_params = networks.policy_network.init(policy_key) policy_optimizer_state = policy_optimizer.init(policy_init_params) nu_init_params = networks.nu_network.init(nu_key) nu_optimizer_state = nu_optimizer.init(nu_init_params) def compute_losses( policy_params: networks_lib.Params, nu_params: networks_lib.Params, key: jnp.ndarray, replay_o_tm1: types.NestedArray, replay_a_tm1: types.NestedArray, replay_o_t: types.NestedArray, demo_o_tm1: types.NestedArray, demo_a_tm1: types.NestedArray, demo_o_t: types.NestedArray, ) -> jnp.ndarray: # TODO(damienv, hussenot): what to do with the discounts ? def policy(obs, key): dist_params = networks.policy_network.apply(policy_params, obs) return networks.sample(dist_params, key) key1, key2, key3, key4 = jax.random.split(key, 4) # Predicted actions. demo_o_t0 = demo_o_tm1 policy_demo_a_t0 = policy(demo_o_t0, key1) policy_demo_a_t = policy(demo_o_t, key2) policy_replay_a_t = policy(replay_o_t, key3) replay_a_tm1 = networks.encode_action(replay_a_tm1) demo_a_tm1 = networks.encode_action(demo_a_tm1) policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0) policy_demo_a_t = networks.encode_action(policy_demo_a_t) policy_replay_a_t = networks.encode_action(policy_replay_a_t) # "Value function" nu over the expert states. nu_demo_t0 = networks.nu_network.apply(nu_params, demo_o_t0, policy_demo_a_t0) nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1, demo_a_tm1) nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t, policy_demo_a_t) nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t # "Value function" nu over the replay buffer states. nu_replay_tm1 = networks.nu_network.apply(nu_params, replay_o_tm1, replay_a_tm1) nu_replay_t = networks.nu_network.apply(nu_params, replay_o_t, policy_replay_a_t) nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t # Linear part of the loss. linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount)) linear_loss_rb = jnp.mean(nu_replay_diff) linear_loss = (linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha) # Non linear part of the loss. nu_replay_demo_diff = jnp.concatenate([nu_demo_diff, nu_replay_diff], axis=0) replay_demo_weights = jnp.concatenate([ jnp.ones_like(nu_demo_diff) * (1 - alpha), jnp.ones_like(nu_replay_diff) * alpha ], axis=0) replay_demo_weights /= jnp.mean(replay_demo_weights) non_linear_loss = jnp.sum( jax.lax.stop_gradient( utils.weighted_softmax(nu_replay_demo_diff, replay_demo_weights, axis=0)) * nu_replay_demo_diff) # Final loss. loss = (non_linear_loss - linear_loss) # Regularized policy loss. if policy_reg_scale > 0.: policy_reg = _orthogonal_regularization_loss(policy_params) else: policy_reg = 0. 
# Gradient penality on nu if nu_reg_scale > 0.0: batch_size = demo_o_tm1.shape[0] c = jax.random.uniform(key4, shape=(batch_size,)) shape_o = [ dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape) ] shape_a = [ dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape) ] c_o = jnp.reshape(c, shape_o) c_a = jnp.reshape(c, shape_a) mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1 mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1 mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0) mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0) def sum_nu(o, a): return jnp.sum(networks.nu_network.apply(nu_params, o, a)) nu_grad_o_fn = jax.grad(sum_nu, argnums=0) nu_grad_a_fn = jax.grad(sum_nu, argnums=1) nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a) nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a) nu_grad = jnp.concatenate([ jnp.reshape(nu_grad_o, [batch_size, -1]), jnp.reshape(nu_grad_a, [batch_size, -1])], axis=-1) # TODO(damienv, hussenot): check for the need of eps # (like in the original value dice code). nu_grad_penalty = jnp.mean( jnp.square( jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1)) else: nu_grad_penalty = 0.0 policy_loss = -loss + policy_reg_scale * policy_reg nu_loss = loss + nu_reg_scale * nu_grad_penalty return policy_loss, nu_loss def sgd_step( state: TrainingState, data: Tuple[types.Transition, types.Transition] ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: replay_transitions, demo_transitions = data key, key_loss = jax.random.split(state.key) compute_losses_with_input = functools.partial( compute_losses, replay_o_tm1=replay_transitions.observation, replay_a_tm1=replay_transitions.action, replay_o_t=replay_transitions.next_observation, demo_o_tm1=demo_transitions.observation, demo_a_tm1=demo_transitions.action, demo_o_t=demo_transitions.next_observation, key=key_loss) (policy_loss_value, nu_loss_value), vjpfun = jax.vjp( compute_losses_with_input, state.policy_params, state.nu_params) policy_gradients, _ = vjpfun((1.0, 0.0)) _, nu_gradients = vjpfun((0.0, 1.0)) # Update optimizers. policy_update, policy_optimizer_state = policy_optimizer.update( policy_gradients, state.policy_optimizer_state) policy_params = optax.apply_updates(state.policy_params, policy_update) nu_update, nu_optimizer_state = nu_optimizer.update( nu_gradients, state.nu_optimizer_state) nu_params = optax.apply_updates(state.nu_params, nu_update) new_state = TrainingState( policy_optimizer_state=policy_optimizer_state, policy_params=policy_params, nu_optimizer_state=nu_optimizer_state, nu_params=nu_params, key=key, steps=state.steps + 1, ) metrics = { 'policy_loss': policy_loss_value, 'nu_loss': nu_loss_value, } return new_state, metrics # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( 'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray, steps_key=self._counter.get_steps_key()) # Iterator on demonstration transitions. self._iterator_demonstrations = iterator_demonstrations self._iterator_replay = iterator_replay self._sgd_step = jax.jit(utils.process_multiple_batches( sgd_step, num_sgd_steps_per_step)) # Create initial state. 
self._state = TrainingState( policy_optimizer_state=policy_optimizer_state, policy_params=policy_init_params, nu_optimizer_state=nu_optimizer_state, nu_params=nu_init_params, key=rng, steps=0, ) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
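# Illustrative sketch (toy losses and parameters): the jax.vjp trick used in
# sgd_step above, where a single forward pass yields both losses and the VJP
# is evaluated twice to obtain each loss's gradient separately.
import jax
import jax.numpy as jnp

def two_losses(policy_params, nu_params):
  policy_loss = jnp.sum(policy_params ** 2) + jnp.sum(nu_params)
  nu_loss = jnp.sum(nu_params ** 2) - jnp.sum(policy_params)
  return policy_loss, nu_loss

policy_params = jnp.ones((2,))
nu_params = jnp.ones((3,))
(policy_loss, nu_loss), vjpfun = jax.vjp(two_losses, policy_params, nu_params)
policy_gradients, _ = vjpfun((1.0, 0.0))  # gradient of the policy loss
_, nu_gradients = vjpfun((0.0, 1.0))      # gradient of the nu loss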
def __init__(self, network: networks_lib.FeedForwardNetwork, obs_spec: specs.Array, loss_fn: LossFn, optimizer: optax.GradientTransformation, data_iterator: Iterator[reverb.ReplaySample], target_update_period: int, random_key: networks_lib.PRNGKey, replay_client: Optional[reverb.Client] = None, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None): """Initialize the SGD learner.""" self.network = network # Internalize the loss_fn with network. self._loss = jax.jit(functools.partial(loss_fn, self.network)) # SGD performs the loss, optimizer update and periodic target net update. def sgd_step( state: TrainingState, batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]: next_rng_key, rng_key = jax.random.split(state.rng_key) # Implements one SGD step of the loss and updates training state (loss, extra), grads = jax.value_and_grad( self._loss, has_aux=True)(state.params, state.target_params, batch, rng_key) extra.metrics.update({'total_loss': loss}) # Apply the optimizer updates updates, new_opt_state = optimizer.update(grads, state.opt_state) new_params = optax.apply_updates(state.params, updates) # Periodically update target networks. steps = state.steps + 1 target_params = rlax.periodic_update(new_params, state.target_params, steps, target_update_period) new_training_state = TrainingState(new_params, target_params, new_opt_state, steps, next_rng_key) return new_training_state, extra self._sgd_step = jax.jit(sgd_step) # Internalise agent components self._data_iterator = utils.prefetch(data_iterator) self._target_update_period = target_update_period self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Initialize the network parameters dummy_obs = utils.add_batch_dim(utils.zeros_like(obs_spec)) key_params, key_target, key_state = jax.random.split(random_key, 3) initial_params = self.network.init(key_params, dummy_obs) initial_target_params = self.network.init(key_target, dummy_obs) self._state = TrainingState( params=initial_params, target_params=initial_target_params, opt_state=optimizer.init(initial_params), steps=0, rng_key=key_state, ) # Update replay priorities def update_priorities(reverb_update: Optional[ReverbUpdate]) -> None: if reverb_update is None or replay_client is None: return else: replay_client.mutate_priorities( table=adders.DEFAULT_PRIORITY_TABLE, updates=dict( zip(reverb_update.keys, reverb_update.priorities))) self._replay_client = replay_client self._async_priority_updater = async_utils.AsyncExecutor( update_priorities)
def __init__( self, preprocessor: processors.Processor, sample_network_input: IqnInputs, network: parts.Network, optimizer: optax.GradientTransformation, transition_accumulator: Any, replay: replay_lib.TransitionReplay, batch_size: int, exploration_epsilon: Callable[[int], float], min_replay_capacity_fraction: float, learn_period: int, target_network_update_period: int, huber_param: float, tau_samples_policy: int, tau_samples_s_tm1: int, tau_samples_s_t: int, rng_key: parts.PRNGKey, ): self._preprocessor = preprocessor self._replay = replay self._transition_accumulator = transition_accumulator self._batch_size = batch_size self._exploration_epsilon = exploration_epsilon self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity self._learn_period = learn_period self._target_network_update_period = target_network_update_period # Initialize network parameters and optimizer. self._rng_key, network_rng_key = jax.random.split(rng_key) self._online_params = network.init( network_rng_key, jax.tree_map(lambda x: x[None, ...], sample_network_input)) self._target_params = self._online_params self._opt_state = optimizer.init(self._online_params) # Other agent state: last action, frame count, etc. self._action = None self._frame_t = -1 # Current frame index. # Define jitted loss, update, and policy functions here instead of as # class methods, to emphasize that these are meant to be pure functions # and should not access the agent object's state via `self`. def loss_fn(online_params, target_params, transitions, rng_key): """Calculates loss given network parameters and transitions.""" # Sample tau values for q_tm1, q_t_selector, q_t. batch_size = self._batch_size rng_key, *sample_keys = jax.random.split(rng_key, 4) tau_tm1 = _sample_tau(sample_keys[0], (batch_size, tau_samples_s_tm1)) tau_t_selector = _sample_tau(sample_keys[1], (batch_size, tau_samples_policy)) tau_t = _sample_tau(sample_keys[2], (batch_size, tau_samples_s_t)) # Compute Q value distributions. 
_, *apply_keys = jax.random.split(rng_key, 4) dist_q_tm1 = network.apply(online_params, apply_keys[0], IqnInputs(transitions.s_tm1, tau_tm1)).q_dist dist_q_t_selector = network.apply( target_params, apply_keys[1], IqnInputs(transitions.s_t, tau_t_selector)).q_dist dist_q_target_t = network.apply(target_params, apply_keys[2], IqnInputs(transitions.s_t, tau_t)).q_dist losses = _batch_quantile_q_learning( dist_q_tm1, tau_tm1, transitions.a_tm1, transitions.r_t, transitions.discount_t, dist_q_t_selector, dist_q_target_t, huber_param, ) assert losses.shape == (self._batch_size,) loss = jnp.mean(losses) return loss def update(rng_key, opt_state, online_params, target_params, transitions): """Computes learning update from batch of replay transitions.""" rng_key, update_key = jax.random.split(rng_key) d_loss_d_params = jax.grad(loss_fn)(online_params, target_params, transitions, update_key) updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state) new_online_params = optax.apply_updates(online_params, updates) return rng_key, new_opt_state, new_online_params self._update = jax.jit(update) def select_action(rng_key, network_params, s_t, exploration_epsilon): """Samples action from eps-greedy policy wrt Q-values at given state.""" rng_key, sample_key, apply_key, policy_key = jax.random.split(rng_key, 4) tau_t = _sample_tau(sample_key, (1, tau_samples_policy)) q_t = network.apply(network_params, apply_key, IqnInputs(s_t[None, ...], tau_t)).q_values[0] a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon) return rng_key, a_t self._select_action = jax.jit(select_action)
def __init__( self, preprocessor: processors.Processor, sample_network_input: jnp.ndarray, network: parts.Network, optimizer: optax.GradientTransformation, transition_accumulator: replay_lib.TransitionAccumulator, replay: replay_lib.PrioritizedTransitionReplay, batch_size: int, exploration_epsilon: Callable[[int], float], min_replay_capacity_fraction: float, learn_period: int, target_network_update_period: int, grad_error_bound: float, rng_key: parts.PRNGKey, ): self._preprocessor = preprocessor self._replay = replay self._transition_accumulator = transition_accumulator self._batch_size = batch_size self._exploration_epsilon = exploration_epsilon self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity self._learn_period = learn_period self._target_network_update_period = target_network_update_period # Initialize network parameters and optimizer. self._rng_key, network_rng_key = jax.random.split(rng_key) self._online_params = network.init(network_rng_key, sample_network_input[None, ...]) self._target_params = self._online_params self._opt_state = optimizer.init(self._online_params) # Other agent state: last action, frame count, etc. self._action = None self._frame_t = -1 # Current frame index. self._max_seen_priority = 1. # Define jitted loss, update, and policy functions here instead of as # class methods, to emphasize that these are meant to be pure functions # and should not access the agent object's state via `self`. def loss_fn(online_params, target_params, transitions, weights, rng_key): """Calculates loss given network parameters and transitions.""" _, *apply_keys = jax.random.split(rng_key, 4) q_tm1 = network.apply(online_params, apply_keys[0], transitions.s_tm1).q_values q_t = network.apply(online_params, apply_keys[1], transitions.s_t).q_values q_target_t = network.apply(target_params, apply_keys[2], transitions.s_t).q_values td_errors = _batch_double_q_learning( q_tm1, transitions.a_tm1, transitions.r_t, transitions.discount_t, q_target_t, q_t, ) td_errors = rlax.clip_gradient(td_errors, -grad_error_bound, grad_error_bound) losses = rlax.l2_loss(td_errors) assert losses.shape == (self._batch_size, ) == weights.shape # This is not the same as using a huber loss and multiplying by weights. loss = jnp.mean(losses * weights) return loss, td_errors def update(rng_key, opt_state, online_params, target_params, transitions, weights): """Computes learning update from batch of replay transitions.""" rng_key, update_key = jax.random.split(rng_key) d_loss_d_params, td_errors = jax.grad(loss_fn, has_aux=True)( online_params, target_params, transitions, weights, update_key) updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state) new_online_params = optax.apply_updates(online_params, updates) return rng_key, new_opt_state, new_online_params, td_errors self._update = jax.jit(update) def select_action(rng_key, network_params, s_t, exploration_epsilon): """Samples action from eps-greedy policy wrt Q-values at given state.""" rng_key, apply_key, policy_key = jax.random.split(rng_key, 3) q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0] a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon) return rng_key, a_t self._select_action = jax.jit(select_action)
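# Illustrative sketch (toy TD errors and weights; the clip bound of 1.0 is
# arbitrary): the weighted, gradient-clipped loss above, where the TD errors
# pass through unchanged in the forward pass while the gradient flowing back
# into them is clipped.
import jax
import jax.numpy as jnp
import rlax

def loss_fn(td_errors, weights):
  clipped = rlax.clip_gradient(td_errors, -1.0, 1.0)
  return jnp.mean(rlax.l2_loss(clipped) * weights)

td_errors = jnp.array([0.5, -2.0, 3.0])
weights = jnp.array([1.0, 0.5, 0.25])
grads = jax.grad(loss_fn)(td_errors, weights)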