def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: PolicyValueNet,
    optimizer: optix.InitUpdate,
    rng: hk.PRNGSequence,
    sequence_length: int,
    discount: float,
    td_lambda: float,
):

  # Define loss function.
  def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
    """Actor-critic loss."""
    logits, values = network(trajectory.observations)
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones_like(td_errors))
    return actor_loss + critic_loss

  # Transform the loss into a pure function.
  loss_fn = hk.transform(loss).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: TrainingState,
               trajectory: sequence.Trajectory) -> TrainingState:
    """Does a step of SGD over a trajectory."""
    gradients = jax.grad(loss_fn)(state.params, trajectory)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)

  # Initialize network parameters and optimiser state.
  init, forward = hk.transform(network)
  dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
  initial_params = init(next(rng), dummy_observation)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = TrainingState(initial_params, initial_opt_state)
  self._forward = jax.jit(forward)
  self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
  self._sgd_step = sgd_step
  self._rng = rng
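# Illustrative only: a self-contained sketch (made-up shapes and values) of how
# the rlax losses above compose for a length-T trajectory. Not part of the
# agent above.
import jax.numpy as jnp
import rlax

T, num_actions = 5, 3
values = jnp.zeros([T + 1])               # one value per visited state
rewards = jnp.ones([T])
discounts = jnp.full([T], 0.99)
logits = jnp.zeros([T + 1, num_actions])  # one set of policy logits per state
actions = jnp.zeros([T], dtype=jnp.int32)

td_errors = rlax.td_lambda(
    v_tm1=values[:-1], r_t=rewards, discount_t=discounts,
    v_t=values[1:], lambda_=jnp.array(0.9))
critic_loss = jnp.mean(td_errors**2)
actor_loss = rlax.policy_gradient_loss(
    logits_t=logits[:-1], a_t=actions, adv_t=td_errors,
    w_t=jnp.ones_like(td_errors))
combined = actor_loss + critic_loss       # scalar loss, as in `loss` above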
def __init__(self,
             network: networks.QNetwork,
             obs_spec: specs.Array,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optix.InitUpdate,
             rng: hk.PRNGSequence,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: reverb.Client = None,
             counter: counting.Counter = None,
             logger: loggers.Logger = None):
  """Initializes the learner."""

  # Transform network into a pure function.
  network = hk.transform(network)

  def loss(params: hk.Params, target_params: hk.Params,
           sample: reverb.ReplaySample):
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data
    keys, probs = sample.info[:2]

    # Forward pass.
    q_tm1 = network.apply(params, o_tm1)
    q_t_value = network.apply(target_params, o_t)
    q_t_selector = network.apply(params, o_t)

    # Cast and clip rewards.
    d_t = (d_t * discount).astype(jnp.float32)
    r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
    batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    mean_loss = jnp.mean(importance_weights * batch_loss)  # []

    priorities = jnp.abs(td_error).astype(jnp.float64)

    return mean_loss, (keys, priorities)

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
    grad_fn = jax.grad(loss, has_aux=True)
    gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                            samples)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

    outputs = LearnerOutputs(keys=keys, priorities=priorities)

    return new_state, outputs

  def update_priorities(outputs: LearnerOutputs):
    for key, priority in zip(outputs.keys, outputs.priorities):
      replay_client.mutate_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          updates={key: priority})

  # Internalise agent components (replay buffer, networks, optimizer).
  self._replay_client = replay_client
  self._iterator = utils.prefetch(iterator)

  # Internalise the hyperparameters.
  self._target_update_period = target_update_period

  # Internalise logging/counting objects.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Initialise parameters and optimiser state.
  initial_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_target_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_opt_state = optimizer.init(initial_params)

  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      step=0)

  self._forward = jax.jit(network.apply)
  self._sgd_step = jax.jit(sgd_step)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
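# Illustrative only: a self-contained sketch (made-up shapes and values) of the
# batched double-Q TD-error and the prioritized-replay importance weights used
# in the loss above. Not part of the learner.
import jax
import jax.numpy as jnp
import rlax

batch_size, num_actions = 4, 3
q_tm1 = jnp.zeros([batch_size, num_actions])
a_tm1 = jnp.zeros([batch_size], dtype=jnp.int32)
r_t = jnp.ones([batch_size])
d_t = jnp.full([batch_size], 0.99)
q_t_value = jnp.zeros([batch_size, num_actions])
q_t_selector = jnp.zeros([batch_size, num_actions])

td_error = jax.vmap(rlax.double_q_learning)(
    q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)   # [batch_size]
batch_loss = rlax.huber_loss(td_error, 1.)              # [batch_size]

probs = jnp.array([0.1, 0.2, 0.3, 0.4])                 # replay sampling probabilities
importance_weights = (1. / probs) ** 0.6                # exponent 0.6 is made up
importance_weights /= jnp.max(importance_weights)
mean_loss = jnp.mean(importance_weights * batch_loss)   # scalar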
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: hk.Transformed,
    num_ensemble: int,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: optix.InitUpdate,
    mask_prob: float,
    noise_scale: float,
    epsilon_fn: Callable[[int], float] = lambda _: 0.,
    seed: int = 1,
):
  """Bootstrapped DQN with randomized prior functions."""

  # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`.
  def loss(params: hk.Params, target_params: hk.Params,
           transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
    """Q-learning loss with added reward noise + half-in bootstrap."""
    o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
    q_tm1 = network.apply(params, o_tm1)
    q_t = network.apply(target_params, o_t)
    r_t += noise_scale * z_t
    batch_q_learning = jax.vmap(rlax.q_learning)
    td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
    return jnp.mean(m_t * td_error**2)

  # Define update function for each member of the ensemble.
  @jax.jit
  def sgd_step(state: TrainingState,
               transitions: Sequence[jnp.ndarray]) -> TrainingState:
    """Does a step of SGD for a single ensemble member over `transitions`."""
    gradients = jax.grad(loss)(state.params, state.target_params, transitions)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

  # Initialize parameters and optimizer state for an ensemble of Q-networks.
  rng = hk.PRNGSequence(seed)
  dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
  initial_params = [
      network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
  ]
  initial_target_params = [
      network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
  ]
  initial_opt_state = [optimizer.init(p) for p in initial_params]

  # Internalize state.
  self._ensemble = [
      TrainingState(p, tp, o, step=0)
      for p, tp, o in zip(initial_params, initial_target_params,
                          initial_opt_state)
  ]
  self._forward = jax.jit(network.apply)
  self._sgd_step = sgd_step
  self._num_ensemble = num_ensemble
  self._optimizer = optimizer
  self._replay = replay.Replay(capacity=replay_capacity)

  # Agent hyperparameters.
  self._num_actions = action_spec.num_values
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._min_replay_size = min_replay_size
  self._epsilon_fn = epsilon_fn
  self._mask_prob = mask_prob

  # Agent state.
  self._active_head = self._ensemble[0]
  self._total_steps = 0
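# Illustrative only: one common way to draw the per-transition bootstrap mask
# `m_t` and reward noise `z_t` that the loss above consumes (one entry per
# ensemble member). This is an assumption about the surrounding agent code,
# not part of the listing.
import numpy as np

num_ensemble, mask_prob = 20, 0.5
m_t = np.random.binomial(1, mask_prob, num_ensemble).astype(np.float32)
z_t = np.random.randn(num_ensemble).astype(np.float32)
# Member k trains on this transition only where m_t[k] == 1, and sees its
# reward perturbed by noise_scale * z_t[k] inside `loss`.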
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: hk.Transformed,
    optimizer: optix.InitUpdate,
    batch_size: int,
    epsilon: float,
    rng: hk.PRNGSequence,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
):

  # Define loss function.
  def loss(params: hk.Params, target_params: hk.Params,
           transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
    """Computes the standard TD(0) Q-learning loss on a batch of transitions."""
    o_tm1, a_tm1, r_t, d_t, o_t = transitions
    q_tm1 = network.apply(params, o_tm1)
    q_t = network.apply(target_params, o_t)
    batch_q_learning = jax.vmap(rlax.q_learning)
    td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
    return jnp.mean(td_error**2)

  # Define update function.
  @jax.jit
  def sgd_step(state: TrainingState,
               transitions: Sequence[jnp.ndarray]) -> TrainingState:
    """Performs an SGD step on a batch of transitions."""
    gradients = jax.grad(loss)(state.params, state.target_params, transitions)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

  # Initialize the networks and optimizer.
  dummy_observation = np.zeros((1, *obs_spec.shape), jnp.float32)
  initial_params = network.init(next(rng), dummy_observation)
  initial_target_params = network.init(next(rng), dummy_observation)
  initial_opt_state = optimizer.init(initial_params)

  # This carries the agent state relevant to training.
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      step=0)
  self._sgd_step = sgd_step
  self._forward = jax.jit(network.apply)
  self._replay = replay.Replay(capacity=replay_capacity)

  # Store hyperparameters.
  self._num_actions = action_spec.num_values
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._epsilon = epsilon
  self._total_steps = 0
  self._min_replay_size = min_replay_size
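# Illustrative only: one way an epsilon-greedy actor could pick actions from
# the jitted forward pass above; `rlax.epsilon_greedy` is assumed here and is
# not part of the listing.
import jax
import jax.numpy as jnp
import rlax

key = jax.random.PRNGKey(0)
q_values = jnp.array([0.1, 0.5, 0.2])   # stand-in for a forward pass over one observation
action = rlax.epsilon_greedy(0.05).sample(key, q_values)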
def __init__(
    self,
    network: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], networks.RNNState],
    obs_spec: specs.Array,
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optix.InitUpdate,
    rng: hk.PRNGSequence,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
):

  # Initialise training state (parameters & optimiser state).
  network = hk.transform(network)
  initial_network_state = hk.transform(initial_state_fn).apply(None)
  initial_params = network.init(
      next(rng), jax_utils.zeros_like(obs_spec), initial_network_state)
  initial_opt_state = optimizer.init(initial_params)

  def loss(params: hk.Params, sample: reverb.ReplaySample):
    """V-trace loss."""

    # Extract the data.
    observations, actions, rewards, discounts, extra = sample.data
    initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
    behaviour_logits = extra['logits']

    # Drop the final transition; it is only needed to bootstrap the value.
    actions = actions[:-1]      # [T-1]
    rewards = rewards[:-1]      # [T-1]
    discounts = discounts[:-1]  # [T-1]
    rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

    # Unroll current policy over observations.
    net = functools.partial(network.apply, params)
    (logits, values), _ = hk.static_unroll(net, observations, initial_state)

    # Compute importance sampling weights: current policy / behaviour policy.
    rhos = rlax.categorical_importance_sampling_ratios(
        logits[:-1], behaviour_logits[:-1], actions)

    # Critic loss.
    vtrace_returns = rlax.vtrace_td_error_and_advantage(
        v_tm1=values[:-1],
        v_t=values[1:],
        r_t=rewards,
        discount_t=discounts * discount,
        rho_t=rhos)
    critic_loss = jnp.square(vtrace_returns.errors)

    # Policy-gradient loss.
    policy_gradient_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=actions,
        adv_t=vtrace_returns.pg_advantage,
        w_t=jnp.ones_like(rewards))

    # Entropy regulariser.
    entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards))

    # Combine the weighted actor, critic & entropy losses.
    mean_loss = jnp.mean(policy_gradient_loss + baseline_cost * critic_loss +
                         entropy_cost * entropy_loss)
    return mean_loss

  @jax.jit
  def sgd_step(state: TrainingState, sample: reverb.ReplaySample):
    # Compute the mean loss over the batch and its gradients.
    batch_loss = jax.vmap(loss, in_axes=(None, 0))
    mean_loss = lambda p, s: jnp.mean(batch_loss(p, s))
    grad_fn = jax.value_and_grad(mean_loss)
    loss_value, gradients = grad_fn(state.params, sample)

    # Apply the optimizer updates to obtain new parameters.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)
    return new_state, metrics

  self._state = TrainingState(params=initial_params,
                              opt_state=initial_opt_state)

  # Internalise iterator.
  self._iterator = jax_utils.prefetch(iterator)
  self._sgd_step = sgd_step

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
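# Illustrative only: the V-trace pieces above applied to a dummy sequence of
# length T, with made-up shapes, outside of any learner.
import jax.numpy as jnp
import rlax

T, num_actions = 6, 4
logits = jnp.zeros([T + 1, num_actions])            # current policy per state
behaviour_logits = jnp.zeros([T + 1, num_actions])  # actor's policy per state
actions = jnp.zeros([T], dtype=jnp.int32)
values = jnp.zeros([T + 1])
rewards = jnp.ones([T])
discounts = jnp.full([T], 0.99)

rhos = rlax.categorical_importance_sampling_ratios(
    logits[:-1], behaviour_logits[:-1], actions)
vtrace_returns = rlax.vtrace_td_error_and_advantage(
    v_tm1=values[:-1], v_t=values[1:], r_t=rewards,
    discount_t=discounts, rho_t=rhos)
print(vtrace_returns.errors.shape, vtrace_returns.pg_advantage.shape)  # (6,) (6,)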
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: RecurrentPolicyValueNet,
    initial_rnn_state: LSTMState,
    optimizer: optix.InitUpdate,
    rng: hk.PRNGSequence,
    sequence_length: int,
    discount: float,
    td_lambda: float,
    entropy_cost: float = 0.,
):

  # Define loss function.
  def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
    """Actor-critic loss."""
    (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
        network, trajectory.observations[:, None, ...], rnn_unroll_state)
    seq_len = trajectory.actions.shape[0]
    td_errors = rlax.td_lambda(
        v_tm1=values[:-1, 0],
        r_t=trajectory.rewards,
        discount_t=trajectory.discounts * discount,
        v_t=values[1:, 0],
        lambda_=jnp.array(td_lambda),
    )
    critic_loss = jnp.mean(td_errors**2)
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1, 0],
        a_t=trajectory.actions,
        adv_t=td_errors,
        w_t=jnp.ones(seq_len))
    entropy_loss = jnp.mean(
        rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

    combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss
    return combined_loss, new_rnn_unroll_state

  # Transform the loss into a pure function.
  loss_fn = hk.without_apply_rng(hk.transform(loss, apply_rng=True)).apply

  # Define update function.
  @jax.jit
  def sgd_step(state: AgentState,
               trajectory: sequence.Trajectory) -> AgentState:
    """Does a step of SGD over a trajectory."""
    gradients, new_rnn_state = jax.grad(
        loss_fn, has_aux=True)(state.params, trajectory,
                               state.rnn_unroll_state)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return state._replace(
        params=new_params,
        opt_state=new_opt_state,
        rnn_unroll_state=new_rnn_state)

  # Initialize network parameters and optimiser state.
  init, forward = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=obs_spec.dtype)
  initial_params = init(next(rng), dummy_observation, initial_rnn_state)
  initial_opt_state = optimizer.init(initial_params)

  # Internalize state.
  self._state = AgentState(initial_params, initial_opt_state,
                           initial_rnn_state, initial_rnn_state)
  self._forward = jax.jit(forward)
  self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
  self._sgd_step = sgd_step
  self._rng = rng
  self._initial_rnn_state = initial_rnn_state
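# Illustrative only: the `jax.grad(..., has_aux=True)` pattern used in
# `sgd_step` above, which returns gradients together with an auxiliary output
# (here a dummy prediction; above, the new RNN unroll state).
import jax
import jax.numpy as jnp

def loss_with_aux(w, x):
  prediction = w * x
  loss = jnp.mean((prediction - 1.) ** 2)
  return loss, prediction            # (differentiated value, auxiliary output)

gradient, aux = jax.grad(loss_with_aux, has_aux=True)(2., jnp.ones(3))
print(gradient, aux)                 # d(loss)/dw and the auxiliary prediction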