def __init__(self,
             direct_rl_learner_factory: Callable[
                 [Any, Iterator[reverb.ReplaySample]], acme.Learner],
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             rnd_network: rnd_networks.RNDNetworks,
             rng_key: jnp.ndarray,
             grad_updates_per_batch: int,
             is_sequence_based: bool,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  self._is_sequence_based = is_sequence_based

  target_key, predictor_key = jax.random.split(rng_key)
  target_params = rnd_network.target.init(target_key)
  predictor_params = rnd_network.predictor.init(predictor_key)
  optimizer_state = optimizer.init(predictor_params)

  self._state = RNDTrainingState(
      optimizer_state=optimizer_state,
      params=predictor_params,
      target_params=target_params,
      steps=0)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  loss = functools.partial(rnd_loss, networks=rnd_network)
  self._update = functools.partial(
      rnd_update_step, loss_fn=loss, optimizer=optimizer)
  self._update = utils.process_multiple_batches(self._update,
                                                grad_updates_per_batch)
  self._update = jax.jit(self._update)

  self._get_reward = jax.jit(
      functools.partial(
          rnd_networks.compute_rnd_reward, networks=rnd_network))

  # Generator expression that works the same as an iterator.
  # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
  updated_iterator = (self._process_sample(sample) for sample in iterator)

  self._direct_rl_learner = direct_rl_learner_factory(
      rnd_network.direct_rl_networks, updated_iterator)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
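# A minimal standalone sketch of the lazy-relabeling pattern used above:
# wrapping an iterator in a generator expression applies a per-sample
# transform without materializing the stream. `relabel` is a hypothetical
# stand-in for `self._process_sample`, which recomputes rewards with the
# RND intrinsic bonus.
import itertools

def relabel(sample):
  # Replace with any per-sample transformation, e.g. adding an intrinsic
  # reward to `sample`.
  return sample

raw_iterator = iter(range(5))  # Stands in for a reverb sample iterator.
relabeled = (relabel(s) for s in raw_iterator)
print(list(itertools.islice(relabeled, 3)))  # Consumed lazily: [0, 1, 2]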
    # jax.tree_map(lambda x: jnp.std(x, axis=0),
    #              transitions.next_observation)))
    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray)

  # Iterator on demonstration transitions.
  self._iterator = iterator

  # update_step = utils.process_multiple_batches(update_step,
  #                                              num_sgd_steps_per_step)
  self._update_step_in_initial_bc_iters = utils.process_multiple_batches(
      lambda x, y: update_step(x, y, True), num_sgd_steps_per_step)
  self._update_step_rest = utils.process_multiple_batches(
      lambda x, y: update_step(x, y, False), num_sgd_steps_per_step)

  # Use the JIT compiler.
  self._update_step = jax.jit(update_step)

  def make_initial_state(key):
    """Initialises the training state (parameters and optimiser state)."""
    key_policy, key_q, key = jax.random.split(key, 3)
    devices = jax.local_devices()
    policy_params = networks.policy_network.init(key_policy)
    policy_optimizer_state = policy_optimizer.init(policy_params)
    policy_params = jax.device_put_replicated(policy_params, devices)
    policy_optimizer_state = jax.device_put_replicated(
        policy_optimizer_state, devices)
def __init__(self,
             networks: value_dice_networks.ValueDiceNetworks,
             policy_optimizer: optax.GradientTransformation,
             nu_optimizer: optax.GradientTransformation,
             discount: float,
             rng: jnp.ndarray,
             iterator_replay: Iterator[reverb.ReplaySample],
             iterator_demonstrations: Iterator[types.Transition],
             alpha: float = 0.05,
             policy_reg_scale: float = 1e-4,
             nu_reg_scale: float = 10.0,
             num_sgd_steps_per_step: int = 1,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):

  rng, policy_key, nu_key = jax.random.split(rng, 3)
  policy_init_params = networks.policy_network.init(policy_key)
  policy_optimizer_state = policy_optimizer.init(policy_init_params)

  nu_init_params = networks.nu_network.init(nu_key)
  nu_optimizer_state = nu_optimizer.init(nu_init_params)

  def compute_losses(
      policy_params: networks_lib.Params,
      nu_params: networks_lib.Params,
      key: jnp.ndarray,
      replay_o_tm1: types.NestedArray,
      replay_a_tm1: types.NestedArray,
      replay_o_t: types.NestedArray,
      demo_o_tm1: types.NestedArray,
      demo_a_tm1: types.NestedArray,
      demo_o_t: types.NestedArray,
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # TODO(damienv, hussenot): what to do with the discounts?

    def policy(obs, key):
      dist_params = networks.policy_network.apply(policy_params, obs)
      return networks.sample(dist_params, key)

    key1, key2, key3, key4 = jax.random.split(key, 4)

    # Predicted actions.
    demo_o_t0 = demo_o_tm1
    policy_demo_a_t0 = policy(demo_o_t0, key1)
    policy_demo_a_t = policy(demo_o_t, key2)
    policy_replay_a_t = policy(replay_o_t, key3)

    replay_a_tm1 = networks.encode_action(replay_a_tm1)
    demo_a_tm1 = networks.encode_action(demo_a_tm1)
    policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0)
    policy_demo_a_t = networks.encode_action(policy_demo_a_t)
    policy_replay_a_t = networks.encode_action(policy_replay_a_t)

    # "Value function" nu over the expert states.
    nu_demo_t0 = networks.nu_network.apply(nu_params, demo_o_t0,
                                           policy_demo_a_t0)
    nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1,
                                            demo_a_tm1)
    nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t,
                                          policy_demo_a_t)
    nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t

    # "Value function" nu over the replay buffer states.
    nu_replay_tm1 = networks.nu_network.apply(nu_params, replay_o_tm1,
                                              replay_a_tm1)
    nu_replay_t = networks.nu_network.apply(nu_params, replay_o_t,
                                            policy_replay_a_t)
    nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t

    # Linear part of the loss.
    linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount))
    linear_loss_rb = jnp.mean(nu_replay_diff)
    linear_loss = (linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha)

    # Non-linear part of the loss.
    nu_replay_demo_diff = jnp.concatenate([nu_demo_diff, nu_replay_diff],
                                          axis=0)
    replay_demo_weights = jnp.concatenate([
        jnp.ones_like(nu_demo_diff) * (1 - alpha),
        jnp.ones_like(nu_replay_diff) * alpha
    ], axis=0)
    replay_demo_weights /= jnp.mean(replay_demo_weights)
    non_linear_loss = jnp.sum(
        jax.lax.stop_gradient(
            utils.weighted_softmax(nu_replay_demo_diff, replay_demo_weights,
                                   axis=0)) * nu_replay_demo_diff)

    # Final loss.
    loss = (non_linear_loss - linear_loss)

    # Regularized policy loss.
    if policy_reg_scale > 0.:
      policy_reg = _orthogonal_regularization_loss(policy_params)
    else:
      policy_reg = 0.
    # Gradient penalty on nu.
    if nu_reg_scale > 0.0:
      batch_size = demo_o_tm1.shape[0]
      c = jax.random.uniform(key4, shape=(batch_size,))
      shape_o = [
          dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape)
      ]
      shape_a = [
          dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape)
      ]
      c_o = jnp.reshape(c, shape_o)
      c_a = jnp.reshape(c, shape_a)
      mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1
      mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1
      mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t
      mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t
      mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0)
      mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0)

      def sum_nu(o, a):
        return jnp.sum(networks.nu_network.apply(nu_params, o, a))

      nu_grad_o_fn = jax.grad(sum_nu, argnums=0)
      nu_grad_a_fn = jax.grad(sum_nu, argnums=1)
      nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a)
      nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a)
      nu_grad = jnp.concatenate([
          jnp.reshape(nu_grad_o, [batch_size, -1]),
          jnp.reshape(nu_grad_a, [batch_size, -1])
      ], axis=-1)
      # TODO(damienv, hussenot): check for the need of eps
      # (like in the original value dice code).
      nu_grad_penalty = jnp.mean(
          jnp.square(
              jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1))
    else:
      nu_grad_penalty = 0.0

    policy_loss = -loss + policy_reg_scale * policy_reg
    nu_loss = loss + nu_reg_scale * nu_grad_penalty

    return policy_loss, nu_loss

  def sgd_step(
      state: TrainingState,
      data: Tuple[types.Transition, types.Transition]
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    replay_transitions, demo_transitions = data
    key, key_loss = jax.random.split(state.key)
    compute_losses_with_input = functools.partial(
        compute_losses,
        replay_o_tm1=replay_transitions.observation,
        replay_a_tm1=replay_transitions.action,
        replay_o_t=replay_transitions.next_observation,
        demo_o_tm1=demo_transitions.observation,
        demo_a_tm1=demo_transitions.action,
        demo_o_t=demo_transitions.next_observation,
        key=key_loss)
    (policy_loss_value, nu_loss_value), vjpfun = jax.vjp(
        compute_losses_with_input, state.policy_params, state.nu_params)
    policy_gradients, _ = vjpfun((1.0, 0.0))
    _, nu_gradients = vjpfun((0.0, 1.0))

    # Update optimizers.
    policy_update, policy_optimizer_state = policy_optimizer.update(
        policy_gradients, state.policy_optimizer_state)
    policy_params = optax.apply_updates(state.policy_params, policy_update)
    nu_update, nu_optimizer_state = nu_optimizer.update(
        nu_gradients, state.nu_optimizer_state)
    nu_params = optax.apply_updates(state.nu_params, nu_update)

    new_state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_params,
        nu_optimizer_state=nu_optimizer_state,
        nu_params=nu_params,
        key=key,
        steps=state.steps + 1,
    )

    metrics = {
        'policy_loss': policy_loss_value,
        'nu_loss': nu_loss_value,
    }
    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  # Iterator on demonstration transitions.
  self._iterator_demonstrations = iterator_demonstrations
  self._iterator_replay = iterator_replay

  self._sgd_step = jax.jit(utils.process_multiple_batches(
      sgd_step, num_sgd_steps_per_step))

  # Create initial state.
  self._state = TrainingState(
      policy_optimizer_state=policy_optimizer_state,
      policy_params=policy_init_params,
      nu_optimizer_state=nu_optimizer_state,
      nu_params=nu_init_params,
      key=rng,
      steps=0,
  )

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
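# A minimal sketch of the `utils.weighted_softmax` helper assumed by the
# non-linear ValueDice loss above: a softmax whose exponentials are scaled
# by per-element weights, softmax_i = w_i * exp(x_i) / sum_j w_j * exp(x_j).
# This is an illustration, not necessarily the exact library implementation.
import jax.numpy as jnp

def weighted_softmax(x: jnp.ndarray, weights: jnp.ndarray,
                     axis: int = 0) -> jnp.ndarray:
  # Subtract the max for numerical stability; it cancels in the ratio.
  x = x - jnp.max(x, axis=axis, keepdims=True)
  unnormalized = weights * jnp.exp(x)
  return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

x = jnp.array([0.0, 1.0, 2.0])
w = jnp.array([0.5, 0.25, 0.25])
print(weighted_softmax(x, w))  # Sums to 1; weights bias the distribution.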
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             loss_fn: LossFn,
             optimizer: optax.GradientTransformation,
             data_iterator: Iterator[reverb.ReplaySample],
             target_update_period: int,
             random_key: networks_lib.PRNGKey,
             replay_client: Optional[reverb.Client] = None,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             num_sgd_steps_per_step: int = 1):
  """Initialize the SGD learner."""
  self.network = network

  # Internalize the loss_fn with network.
  self._loss = jax.jit(functools.partial(loss_fn, self.network))

  # SGD performs the loss, optimizer update and periodic target net update.
  def sgd_step(
      state: TrainingState,
      batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
    next_rng_key, rng_key = jax.random.split(state.rng_key)
    # Implements one SGD step of the loss and updates training state.
    (loss, extra), grads = jax.value_and_grad(
        self._loss, has_aux=True)(state.params, state.target_params, batch,
                                  rng_key)
    extra.metrics.update({'total_loss': loss})

    # Apply the optimizer updates.
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, target_update_period)

    new_training_state = TrainingState(new_params, target_params,
                                       new_opt_state, steps, next_rng_key)
    return new_training_state, extra

  def postprocess_aux(extra: LossExtra) -> LossExtra:
    reverb_update = jax.tree_map(
        lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_update)
    return extra._replace(
        metrics=jax.tree_map(jnp.mean, extra.metrics),
        reverb_update=reverb_update)

  self._num_sgd_steps_per_step = num_sgd_steps_per_step
  sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step,
                                            postprocess_aux)
  self._sgd_step = jax.jit(sgd_step)

  # Internalise agent components.
  self._data_iterator = utils.prefetch(data_iterator)
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None

  # Initialize the network parameters.
  key_params, key_target, key_state = jax.random.split(random_key, 3)
  initial_params = self.network.init(key_params)
  initial_target_params = self.network.init(key_target)
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=optimizer.init(initial_params),
      steps=0,
      rng_key=key_state,
  )

  # Update replay priorities.
  def update_priorities(reverb_update: ReverbUpdate) -> None:
    if replay_client is None:
      return
    keys, priorities = tree.map_structure(
        utils.fetch_devicearray,
        (reverb_update.keys, reverb_update.priorities))
    replay_client.mutate_priorities(
        table=replay_table_name,
        updates=dict(zip(keys, priorities)))

  self._replay_client = replay_client
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
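# Several learners in this file split one large sampled batch into
# `num_sgd_steps_per_step` minibatches and run one gradient step per
# minibatch inside a single jitted call. A minimal sketch of that pattern,
# assuming the batch's leading dimension is divisible by the step count
# (the exact `utils.process_multiple_batches` implementation may differ):
import jax
import jax.numpy as jnp

def process_multiple_batches(step_fn, num_steps):
  """Wraps `step_fn(state, batch)` to scan over `num_steps` minibatches."""

  def wrapped(state, batch):
    # Reshape [B, ...] -> [num_steps, B // num_steps, ...].
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape((num_steps, -1) + x.shape[1:]), batch)
    state, metrics = jax.lax.scan(step_fn, state, minibatches)
    # Report metrics averaged over the inner steps.
    return state, jax.tree_util.tree_map(jnp.mean, metrics)

  return wrapped

# Usage: a toy "learner" whose state is a running sum of minibatch means.
def step_fn(state, minibatch):
  return state + jnp.mean(minibatch), {'batch_mean': jnp.mean(minibatch)}

update = jax.jit(process_multiple_batches(step_fn, num_steps=4))
state, metrics = update(jnp.float32(0.), jnp.arange(32, dtype=jnp.float32))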
def __init__(self,
             batch_size: int,
             networks: CQLNetworks,
             random_key: networks_lib.PRNGKey,
             demonstrations: Iterator[types.Transition],
             policy_optimizer: optax.GradientTransformation,
             critic_optimizer: optax.GradientTransformation,
             tau: float = 0.005,
             fixed_cql_coefficient: Optional[float] = None,
             cql_lagrange_threshold: Optional[float] = None,
             cql_num_samples: int = 10,
             num_sgd_steps_per_step: int = 1,
             reward_scale: float = 1.0,
             discount: float = 0.99,
             fixed_entropy_coefficient: Optional[float] = None,
             target_entropy: Optional[float] = 0,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the CQL learner.

  Args:
    batch_size: batch size.
    networks: CQL networks.
    random_key: a key for random number generation.
    demonstrations: an iterator over training data.
    policy_optimizer: the policy optimizer.
    critic_optimizer: the Q-function optimizer.
    tau: target smoothing coefficient.
    fixed_cql_coefficient: the value for the cql coefficient. If None, an
      adaptive coefficient will be used.
    cql_lagrange_threshold: a threshold that controls the adaptive loss for
      the cql coefficient.
    cql_num_samples: number of samples used to compute logsumexp(Q) via
      importance sampling.
    num_sgd_steps_per_step: how many gradient updates to perform per step;
      the batch is split into this many smaller batches, so the batch size
      should be a multiple of num_sgd_steps_per_step.
    reward_scale: reward scale.
    discount: discount to use for TD updates.
    fixed_entropy_coefficient: coefficient applied to the entropy bonus. If
      None, an adaptive coefficient will be used.
    target_entropy: target entropy when using the adaptive entropy bonus.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
  """
  adaptive_entropy_coefficient = fixed_entropy_coefficient is None
  action_spec = networks.environment_specs.actions
  if adaptive_entropy_coefficient:
    # sac_alpha is the temperature parameter that determines the relative
    # importance of the entropy term versus the reward.
    log_sac_alpha = jnp.asarray(0., dtype=jnp.float32)
    alpha_optimizer = optax.adam(learning_rate=3e-4)
    alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha)
  else:
    if target_entropy:
      raise ValueError('target_entropy should not be set when '
                       'fixed_entropy_coefficient is provided')

  adaptive_cql_coefficient = fixed_cql_coefficient is None
  if adaptive_cql_coefficient:
    log_cql_alpha = jnp.asarray(0., dtype=jnp.float32)
    cql_optimizer = optax.adam(learning_rate=3e-4)
    cql_optimizer_state = cql_optimizer.init(log_cql_alpha)
  else:
    if cql_lagrange_threshold:
      raise ValueError('cql_lagrange_threshold should not be set when '
                       'fixed_cql_coefficient is provided')

  def alpha_loss(log_sac_alpha: jnp.ndarray,
                 policy_params: networks_lib.Params,
                 transitions: types.Transition,
                 key: jnp.ndarray) -> jnp.ndarray:
    """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
    dist_params = networks.policy_network.apply(policy_params,
                                                transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    sac_alpha = jnp.exp(log_sac_alpha)
    sac_alpha_loss = sac_alpha * jax.lax.stop_gradient(
        -log_prob - target_entropy)
    return jnp.mean(sac_alpha_loss)

  def sac_critic_loss(q_old_action: jnp.ndarray,
                      policy_params: networks_lib.Params,
                      target_critic_params: networks_lib.Params,
                      sac_alpha: jnp.ndarray,
                      transitions: types.Transition,
                      key: networks_lib.PRNGKey) -> jnp.ndarray:
    """Computes the SAC part of the loss."""
    next_dist_params = networks.policy_network.apply(
        policy_params, transitions.next_observation)
    next_action = networks.sample(next_dist_params, key)
    next_log_prob = networks.log_prob(next_dist_params, next_action)
    next_q = networks.critic_network.apply(target_critic_params,
                                           transitions.next_observation,
                                           next_action)
    next_v = jnp.min(next_q, axis=-1) - sac_alpha * next_log_prob
    target_q = jax.lax.stop_gradient(
        transitions.reward * reward_scale +
        transitions.discount * discount * next_v)
    return jnp.mean(jnp.square(q_old_action - jnp.expand_dims(target_q, -1)))

  def batched_critic(actions: jnp.ndarray,
                     critic_params: networks_lib.Params,
                     observation: jnp.ndarray) -> jnp.ndarray:
    """Applies the critic network to a batch of sampled actions."""
    actions = jax.lax.stop_gradient(actions)
    tiled_actions = jnp.reshape(actions, (batch_size * cql_num_samples, -1))
    tiled_states = jnp.tile(observation, [cql_num_samples, 1])
    tiled_q = networks.critic_network.apply(critic_params, tiled_states,
                                            tiled_actions)
    return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1))

  def cql_critic_loss(q_old_action: jnp.ndarray,
                      critic_params: networks_lib.Params,
                      policy_params: networks_lib.Params,
                      transitions: types.Transition,
                      key: networks_lib.PRNGKey) -> jnp.ndarray:
    """Computes the CQL part of the loss."""
    # The CQL part of the loss is
    #   logsumexp(Q(s, .)) - Q(s, a),
    # where s is the current state and a the action in the dataset (so
    # Q(s, a) is simply q_old_action).
    # We need to estimate logsumexp(Q). This is done with importance
    # sampling (IS). This function implements the unlabeled equation on
    # page 29, Appx. F, in https://arxiv.org/abs/2006.04779.
    # Here, IS is done with the uniform distribution and the policy in the
    # current state s. In their implementation, the authors also add the
    # policy in the transiting state s':
    # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py
    # (l. 233-236).
    key_policy, key_policy_next, key_uniform = jax.random.split(key, 3)

    def sampled_q(obs, key):
      actions, log_probs = apply_and_sample_n(
          key, networks, policy_params, obs, cql_num_samples)
      return batched_critic(actions, critic_params,
                            transitions.observation) - jax.lax.stop_gradient(
                                jnp.expand_dims(log_probs, -1))

    # Sample wrt policy in s.
    sampled_q_from_policy = sampled_q(transitions.observation, key_policy)

    # Sample wrt policy in s'.
    sampled_q_from_policy_next = sampled_q(transitions.next_observation,
                                           key_policy_next)

    # Sample wrt uniform.
    actions_uniform = jax.random.uniform(
        key_uniform, (cql_num_samples, batch_size) + action_spec.shape,
        minval=action_spec.minimum, maxval=action_spec.maximum)
    log_prob_uniform = -jnp.sum(
        jnp.log(action_spec.maximum - action_spec.minimum))
    sampled_q_from_uniform = (
        batched_critic(actions_uniform, critic_params,
                       transitions.observation) - log_prob_uniform)

    # Combine the samplings.
    combined = jnp.concatenate(
        (sampled_q_from_uniform, sampled_q_from_policy,
         sampled_q_from_policy_next),
        axis=0)
    lse_q = jax.nn.logsumexp(combined, axis=0, b=1. / (3 * cql_num_samples))
    return jnp.mean(lse_q - q_old_action)

  def critic_loss(critic_params: networks_lib.Params,
                  policy_params: networks_lib.Params,
                  target_critic_params: networks_lib.Params,
                  sac_alpha: jnp.ndarray,
                  cql_alpha: jnp.ndarray,
                  transitions: types.Transition,
                  key: networks_lib.PRNGKey) -> jnp.ndarray:
    """Computes the full critic loss."""
    key_cql, key_sac = jax.random.split(key, 2)
    q_old_action = networks.critic_network.apply(critic_params,
                                                 transitions.observation,
                                                 transitions.action)
    cql_loss = cql_critic_loss(q_old_action, critic_params, policy_params,
                               transitions, key_cql)
    sac_loss = sac_critic_loss(q_old_action, policy_params,
                               target_critic_params, sac_alpha, transitions,
                               key_sac)
    return cql_alpha * cql_loss + sac_loss

  def cql_lagrange_loss(log_cql_alpha: jnp.ndarray,
                        critic_params: networks_lib.Params,
                        policy_params: networks_lib.Params,
                        transitions: types.Transition,
                        key: jnp.ndarray) -> jnp.ndarray:
    """Computes the loss that optimizes the cql coefficient."""
    cql_alpha = jnp.exp(log_cql_alpha)
    q_old_action = networks.critic_network.apply(critic_params,
                                                 transitions.observation,
                                                 transitions.action)
    return -cql_alpha * (
        cql_critic_loss(q_old_action, critic_params, policy_params,
                        transitions, key) - cql_lagrange_threshold)

  def actor_loss(policy_params: networks_lib.Params,
                 critic_params: networks_lib.Params,
                 sac_alpha: jnp.ndarray,
                 transitions: types.Transition,
                 key: jnp.ndarray) -> jnp.ndarray:
    """Computes the loss for the policy."""
    dist_params = networks.policy_network.apply(policy_params,
                                                transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    q_action = networks.critic_network.apply(critic_params,
                                             transitions.observation, action)
    min_q = jnp.min(q_action, axis=-1)
    return jnp.mean(sac_alpha * log_prob - min_q)

  alpha_grad = jax.value_and_grad(alpha_loss)
  cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss)
  critic_grad = jax.value_and_grad(critic_loss)
  actor_grad = jax.value_and_grad(actor_loss)

  def update_step(
      state: TrainingState,
      rb_transitions: types.Transition,
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)

    if adaptive_entropy_coefficient:
      alpha_loss, alpha_grads = alpha_grad(state.log_sac_alpha,
                                           state.policy_params,
                                           rb_transitions, key_alpha)
      sac_alpha = jnp.exp(state.log_sac_alpha)
    else:
      sac_alpha = fixed_entropy_coefficient

    if adaptive_cql_coefficient:
      cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad(
          state.log_cql_alpha, state.critic_params, state.policy_params,
          rb_transitions, key_critic)
      cql_lagrange_grads = jnp.clip(cql_lagrange_grads,
                                    -_CQL_GRAD_CLIPPING_VALUE,
                                    _CQL_GRAD_CLIPPING_VALUE)
      cql_alpha = jnp.exp(state.log_cql_alpha)
      cql_alpha = jnp.clip(
          cql_alpha, a_min=0., a_max=_CQL_COEFFICIENT_MAX_VALUE)
    else:
      cql_alpha = fixed_cql_coefficient

    critic_loss, critic_grads = critic_grad(state.critic_params,
                                            state.policy_params,
                                            state.target_critic_params,
                                            sac_alpha, cql_alpha,
                                            rb_transitions, key_critic)
    actor_loss, actor_grads = actor_grad(state.policy_params,
                                         state.critic_params, sac_alpha,
                                         rb_transitions, key_actor)

    # Apply policy gradients.
    actor_update, policy_optimizer_state = policy_optimizer.update(
        actor_grads, state.policy_optimizer_state)
    policy_params = optax.apply_updates(state.policy_params, actor_update)

    # Apply critic gradients.
    critic_update, critic_optimizer_state = critic_optimizer.update(
        critic_grads, state.critic_optimizer_state)
    critic_params = optax.apply_updates(state.critic_params, critic_update)

    # Note: jax.tree_map replaces the deprecated jax.tree_multimap.
    new_target_critic_params = jax.tree_map(
        lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params,
        critic_params)

    metrics = {
        'critic_loss': critic_loss,
        'actor_loss': actor_loss,
    }

    new_state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        critic_optimizer_state=critic_optimizer_state,
        policy_params=policy_params,
        critic_params=critic_params,
        target_critic_params=new_target_critic_params,
        key=key,
        steps=state.steps + 1,
    )
    if adaptive_entropy_coefficient:
      # Apply sac_alpha gradients.
      alpha_update, alpha_optimizer_state = alpha_optimizer.update(
          alpha_grads, state.alpha_optimizer_state)
      log_sac_alpha = optax.apply_updates(state.log_sac_alpha, alpha_update)
      metrics.update({
          'alpha_loss': alpha_loss,
          'sac_alpha': jnp.exp(log_sac_alpha),
      })
      new_state = new_state._replace(
          alpha_optimizer_state=alpha_optimizer_state,
          log_sac_alpha=log_sac_alpha)

    if adaptive_cql_coefficient:
      # Apply cql coeff gradients.
      cql_update, cql_optimizer_state = cql_optimizer.update(
          cql_lagrange_grads, state.cql_optimizer_state)
      log_cql_alpha = optax.apply_updates(state.log_cql_alpha, cql_update)
      metrics.update({
          'cql_lagrange_loss': cql_lagrange_loss,
          'cql_alpha': jnp.exp(log_cql_alpha),
      })
      new_state = new_state._replace(
          cql_optimizer_state=cql_optimizer_state,
          log_cql_alpha=log_cql_alpha)

    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray)

  # Iterator on demonstration transitions.
  self._demonstrations = demonstrations

  # Use the JIT compiler.
  self._update_step = utils.process_multiple_batches(update_step,
                                                     num_sgd_steps_per_step)
  self._update_step = jax.jit(self._update_step)

  # Create initial state.
  key_policy, key_q, training_state_key = jax.random.split(random_key, 3)
  del random_key
  policy_params = networks.policy_network.init(key_policy)
  policy_optimizer_state = policy_optimizer.init(policy_params)

  critic_params = networks.critic_network.init(key_q)
  critic_optimizer_state = critic_optimizer.init(critic_params)

  self._state = TrainingState(
      policy_optimizer_state=policy_optimizer_state,
      critic_optimizer_state=critic_optimizer_state,
      policy_params=policy_params,
      critic_params=critic_params,
      target_critic_params=critic_params,
      key=training_state_key,
      steps=0)

  if adaptive_entropy_coefficient:
    self._state = self._state._replace(
        alpha_optimizer_state=alpha_optimizer_state,
        log_sac_alpha=log_sac_alpha)
  if adaptive_cql_coefficient:
    self._state = self._state._replace(
        cql_optimizer_state=cql_optimizer_state,
        log_cql_alpha=log_cql_alpha)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
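# The CQL loss above estimates logsumexp(Q(s, .)) by importance sampling and
# relies on the `b` argument of `jax.nn.logsumexp`, which computes
# log(sum_i b_i * exp(x_i)). A small self-contained check of that identity,
# mirroring the uniform 1 / (3 * cql_num_samples) weighting used above:
import jax.numpy as jnp
from jax import nn

x = jnp.array([1.0, 2.0, 3.0])
weights = jnp.full_like(x, 1.0 / x.shape[0])

lhs = nn.logsumexp(x, b=weights)
rhs = jnp.log(jnp.sum(weights * jnp.exp(x)))
print(lhs, rhs)  # Both print the same value: log(mean(exp(x))).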
def __init__(self,
             networks,
             rng,
             policy_optimizer,
             q_optimizer,
             iterator,
             counter,
             logger,
             obs_to_goal,
             config):
  """Initialize the Contrastive RL learner.

  Args:
    networks: Contrastive RL networks.
    rng: a key for random number generation.
    policy_optimizer: the policy optimizer.
    q_optimizer: the Q-function optimizer.
    iterator: an iterator over training data.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    obs_to_goal: a function for extracting the goal coordinates.
    config: the experiment config file.
  """
  if config.add_mc_to_td:
    assert config.use_td
  adaptive_entropy_coefficient = config.entropy_coefficient is None
  self._num_sgd_steps_per_step = config.num_sgd_steps_per_step
  self._obs_dim = config.obs_dim
  self._use_td = config.use_td

  if adaptive_entropy_coefficient:
    # alpha is the temperature parameter that determines the relative
    # importance of the entropy term versus the reward.
    log_alpha = jnp.asarray(0., dtype=jnp.float32)
    alpha_optimizer = optax.adam(learning_rate=3e-4)
    alpha_optimizer_state = alpha_optimizer.init(log_alpha)
  else:
    if config.target_entropy:
      raise ValueError('target_entropy should not be set when '
                       'entropy_coefficient is provided')

  def alpha_loss(log_alpha, policy_params, transitions, key):
    """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
    dist_params = networks.policy_network.apply(policy_params,
                                                transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    alpha = jnp.exp(log_alpha)
    alpha_loss = alpha * jax.lax.stop_gradient(-log_prob -
                                               config.target_entropy)
    return jnp.mean(alpha_loss)

  def critic_loss(q_params, policy_params, target_q_params, transitions,
                  key):
    batch_size = transitions.observation.shape[0]
    # Note: We might be able to speed up the computation for some of the
    # baselines by making a single network that returns all the values. This
    # avoids computing some of the underlying representations multiple
    # times.
    if config.use_td:
      # For TD learning, the diagonal elements are the immediate next state.
      s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
      next_s, _ = jnp.split(transitions.next_observation, [config.obs_dim],
                            axis=1)
      if config.add_mc_to_td:
        next_fraction = (1 - config.discount) / ((1 - config.discount) + 1)
        num_next = int(batch_size * next_fraction)
        new_g = jnp.concatenate([
            obs_to_goal(next_s[:num_next]),
            g[num_next:],
        ], axis=0)
      else:
        new_g = obs_to_goal(next_s)
      obs = jnp.concatenate([s, new_g], axis=1)
      transitions = transitions._replace(observation=obs)

    I = jnp.eye(batch_size)  # pylint: disable=invalid-name
    logits = networks.q_network.apply(q_params, transitions.observation,
                                      transitions.action)

    if config.use_td:
      # Make sure to use the twin Q trick.
      assert len(logits.shape) == 3

      # We evaluate the next-state Q function using random goals.
      s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
      del s
      next_s = transitions.next_observation[:, :config.obs_dim]
      goal_indices = jnp.roll(jnp.arange(batch_size, dtype=jnp.int32), -1)
      g = g[goal_indices]
      transitions = transitions._replace(
          next_observation=jnp.concatenate([next_s, g], axis=1))
      next_dist_params = networks.policy_network.apply(
          policy_params, transitions.next_observation)
      next_action = networks.sample(next_dist_params, key)
      next_q = networks.q_network.apply(
          target_q_params, transitions.next_observation,
          next_action)  # This outputs logits.
      next_q = jax.nn.sigmoid(next_q)
      next_v = jnp.min(next_q, axis=-1)
      next_v = jax.lax.stop_gradient(next_v)
      next_v = jnp.diag(next_v)
      # diag(logits) are predictions for future states.
      # diag(next_q) are predictions for random states, which correspond to
      # the predictions logits[range(B), goal_indices].
      # So, the only thing that's meaningful for next_q is the diagonal.
      # Off-diagonal entries are meaningless and shouldn't be used.
      w = next_v / (1 - next_v)
      w_clipping = 20.0
      w = jnp.clip(w, 0, w_clipping)

      # (B, B, 2) --> (B, 2), computes diagonal of each twin Q.
      pos_logits = jax.vmap(jnp.diag, -1, -1)(logits)
      loss_pos = optax.sigmoid_binary_cross_entropy(
          logits=pos_logits, labels=1)  # [B, 2]

      neg_logits = logits[jnp.arange(batch_size), goal_indices]
      loss_neg1 = w[:, None] * optax.sigmoid_binary_cross_entropy(
          logits=neg_logits, labels=1)  # [B, 2]
      loss_neg2 = optax.sigmoid_binary_cross_entropy(
          logits=neg_logits, labels=0)  # [B, 2]

      if config.add_mc_to_td:
        loss = ((1 + (1 - config.discount)) * loss_pos +
                config.discount * loss_neg1 + 2 * loss_neg2)
      else:
        loss = ((1 - config.discount) * loss_pos +
                config.discount * loss_neg1 + loss_neg2)
      # Take the mean here so that we can compute the accuracy.
      logits = jnp.mean(logits, axis=-1)

    else:
      # For the MC losses.
      def loss_fn(_logits):  # pylint: disable=invalid-name
        if config.use_cpc:
          return (optax.softmax_cross_entropy(logits=_logits, labels=I) +
                  0.01 * jax.nn.logsumexp(_logits, axis=1)**2)
        else:
          return optax.sigmoid_binary_cross_entropy(logits=_logits, labels=I)

      if len(logits.shape) == 3:  # twin q
        # loss.shape = [., num_q]
        loss = jax.vmap(loss_fn, in_axes=2, out_axes=-1)(logits)
        loss = jnp.mean(loss, axis=-1)
        # Take the mean here so that we can compute the accuracy.
        logits = jnp.mean(logits, axis=-1)
      else:
        loss = loss_fn(logits)

    loss = jnp.mean(loss)
    correct = (jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1))
    logits_pos = jnp.sum(logits * I) / jnp.sum(I)
    logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
    if len(logits.shape) == 3:
      logsumexp = jax.nn.logsumexp(logits[:, :, 0], axis=1)**2
    else:
      logsumexp = jax.nn.logsumexp(logits, axis=1)**2
    metrics = {
        'binary_accuracy': jnp.mean((logits > 0) == I),
        'categorical_accuracy': jnp.mean(correct),
        'logits_pos': logits_pos,
        'logits_neg': logits_neg,
        'logsumexp': logsumexp.mean(),
    }

    return loss, metrics

  def actor_loss(
      policy_params,
      q_params,
      alpha,
      transitions,
      key,
  ):
    obs = transitions.observation
    if config.use_gcbc:
      dist_params = networks.policy_network.apply(policy_params, obs)
      log_prob = networks.log_prob(dist_params, transitions.action)
      actor_loss = -1.0 * jnp.mean(log_prob)
    else:
      state = obs[:, :config.obs_dim]
      goal = obs[:, config.obs_dim:]

      if config.random_goals == 0.0:
        new_state = state
        new_goal = goal
      elif config.random_goals == 0.5:
        new_state = jnp.concatenate([state, state], axis=0)
        new_goal = jnp.concatenate([goal, jnp.roll(goal, 1, axis=0)], axis=0)
      else:
        assert config.random_goals == 1.0
        new_state = state
        new_goal = jnp.roll(goal, 1, axis=0)

      new_obs = jnp.concatenate([new_state, new_goal], axis=1)
      dist_params = networks.policy_network.apply(policy_params, new_obs)
      action = networks.sample(dist_params, key)
      log_prob = networks.log_prob(dist_params, action)
      q_action = networks.q_network.apply(q_params, new_obs, action)
      if len(q_action.shape) == 3:  # twin q trick
        assert q_action.shape[2] == 2
        q_action = jnp.mean(q_action, axis=-1)
      actor_loss = alpha * log_prob - jnp.diag(q_action)

    return jnp.mean(actor_loss)

  alpha_grad = jax.value_and_grad(alpha_loss)
  critic_grad = jax.value_and_grad(critic_loss, has_aux=True)
  actor_grad = jax.value_and_grad(actor_loss)

  def update_step(
      state,
      transitions,
  ):
    key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)

    if adaptive_entropy_coefficient:
      alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                           state.policy_params, transitions,
                                           key_alpha)
      alpha = jnp.exp(state.alpha_params)
    else:
      alpha = config.entropy_coefficient

    if not config.use_gcbc:
      (critic_loss, critic_metrics), critic_grads = critic_grad(
          state.q_params, state.policy_params, state.target_q_params,
          transitions, key_critic)

    actor_loss, actor_grads = actor_grad(state.policy_params,
                                         state.q_params, alpha, transitions,
                                         key_actor)

    # Apply policy gradients.
    actor_update, policy_optimizer_state = policy_optimizer.update(
        actor_grads, state.policy_optimizer_state)
    policy_params = optax.apply_updates(state.policy_params, actor_update)

    # Apply critic gradients.
    if config.use_gcbc:
      metrics = {}
      critic_loss = 0.0
      q_params = state.q_params
      q_optimizer_state = state.q_optimizer_state
      new_target_q_params = state.target_q_params
    else:
      critic_update, q_optimizer_state = q_optimizer.update(
          critic_grads, state.q_optimizer_state)
      q_params = optax.apply_updates(state.q_params, critic_update)
      # Note: jax.tree_map replaces the deprecated jax.tree_multimap.
      new_target_q_params = jax.tree_map(
          lambda x, y: x * (1 - config.tau) + y * config.tau,
          state.target_q_params, q_params)
      metrics = critic_metrics

    metrics.update({
        'critic_loss': critic_loss,
        'actor_loss': actor_loss,
    })

    new_state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        q_optimizer_state=q_optimizer_state,
        policy_params=policy_params,
        q_params=q_params,
        target_q_params=new_target_q_params,
        key=key,
    )
    if adaptive_entropy_coefficient:
      # Apply alpha gradients.
      alpha_update, alpha_optimizer_state = alpha_optimizer.update(
          alpha_grads, state.alpha_optimizer_state)
      alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
      metrics.update({
          'alpha_loss': alpha_loss,
          'alpha': jnp.exp(alpha_params),
      })
      new_state = new_state._replace(
          alpha_optimizer_state=alpha_optimizer_state,
          alpha_params=alpha_params)

    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      time_delta=10.0)

  # Iterator on demonstration transitions.
  self._iterator = iterator

  update_step = utils.process_multiple_batches(update_step,
                                               config.num_sgd_steps_per_step)
  # Use the JIT compiler.
  if config.jit:
    self._update_step = jax.jit(update_step)
  else:
    self._update_step = update_step

  def make_initial_state(key):
    """Initialises the training state (parameters and optimiser state)."""
    key_policy, key_q, key = jax.random.split(key, 3)
    policy_params = networks.policy_network.init(key_policy)
    policy_optimizer_state = policy_optimizer.init(policy_params)
    q_params = networks.q_network.init(key_q)
    q_optimizer_state = q_optimizer.init(q_params)
    state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        q_optimizer_state=q_optimizer_state,
        policy_params=policy_params,
        q_params=q_params,
        target_q_params=q_params,
        key=key)
    if adaptive_entropy_coefficient:
      state = state._replace(
          alpha_optimizer_state=alpha_optimizer_state,
          alpha_params=log_alpha)
    return state

  # Create initial state.
  self._state = make_initial_state(rng)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
self._timestamp = None
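# The contrastive critic above treats a [B, B] logits matrix as B^2 binary
# classification problems: diagonal entries are positives (each state paired
# with its own goal), off-diagonal entries are negatives. A small standalone
# illustration of extracting both with the same indexing used above:
import jax.numpy as jnp

batch_size = 3
logits = jnp.arange(9.0).reshape(batch_size, batch_size)
labels = jnp.eye(batch_size)  # The identity matrix I in the loss above.

pos_logits = jnp.diag(logits)                              # [0., 4., 8.]
goal_indices = jnp.roll(jnp.arange(batch_size), -1)
neg_logits = logits[jnp.arange(batch_size), goal_indices]  # [1., 5., 6.]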
def __init__(self,
             counter: counting.Counter,
             direct_rl_learner_factory: Callable[
                 [Iterator[reverb.ReplaySample]], acme.Learner],
             loss_fn: losses.Loss,
             iterator: Iterator[AILSample],
             discriminator_optimizer: optax.GradientTransformation,
             ail_network: ail_networks.AILNetworks,
             discriminator_key: networks_lib.PRNGKey,
             is_sequence_based: bool,
             num_sgd_steps_per_step: int = 1,
             policy_variable_name: Optional[str] = None,
             logger: Optional[loggers.Logger] = None):
  """AIL Learner.

  Args:
    counter: Counter.
    direct_rl_learner_factory: Function that creates the direct RL learner
      when passed a replay sample iterator.
    loss_fn: Discriminator loss.
    iterator: Iterator that returns AILSamples.
    discriminator_optimizer: Discriminator optax optimizer.
    ail_network: AIL networks.
    discriminator_key: RNG key.
    is_sequence_based: If True, the direct rl algorithm is using the
      SequenceAdder data format. Otherwise the learner assumes that the
      direct rl algorithm is using the NStepTransitionAdder.
    num_sgd_steps_per_step: Number of discriminator gradient updates per
      step.
    policy_variable_name: The name of the policy variable used to retrieve
      direct_rl policy parameters.
    logger: Logger.
  """
  self._is_sequence_based = is_sequence_based

  state_key, networks_key = jax.random.split(discriminator_key)

  # Generator expression that works the same as an iterator.
  # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
  iterator, direct_rl_iterator = itertools.tee(iterator)
  direct_rl_iterator = (
      self._process_sample(sample.direct_sample)
      for sample in direct_rl_iterator)
  self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator)

  self._iterator = iterator

  if policy_variable_name is not None:

    def get_policy_params():
      return self._direct_rl_learner.get_variables([policy_variable_name])[0]

    self._get_policy_params = get_policy_params
  else:
    self._get_policy_params = lambda: None

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray)

  # Use the JIT compiler.
  self._update_step = functools.partial(
      ail_update_step,
      optimizer=discriminator_optimizer,
      ail_network=ail_network,
      loss_fn=loss_fn)
  self._update_step = utils.process_multiple_batches(self._update_step,
                                                     num_sgd_steps_per_step)
  self._update_step = jax.jit(self._update_step)

  discriminator_params, discriminator_state = (
      ail_network.discriminator_network.init(networks_key))
  self._state = DiscriminatorTrainingState(
      optimizer_state=discriminator_optimizer.init(discriminator_params),
      discriminator_params=discriminator_params,
      discriminator_state=discriminator_state,
      policy_params=self._get_policy_params(),
      key=state_key,
      steps=0,
  )

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None

  self._get_reward = jax.jit(
      functools.partial(
          ail_networks.compute_ail_reward, networks=ail_network))
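# `itertools.tee` above splits one sample iterator into two independent
# consumers: the discriminator keeps the raw AIL samples while the direct RL
# learner sees relabeled ones. A standalone illustration (note that tee
# buffers items internally, so the two branches should be consumed roughly
# in step):
import itertools

source = iter(range(4))
left, right = itertools.tee(source)
relabeled = (x * 10 for x in right)  # Lazy transform on one branch only.
print(next(left), next(relabeled))   # 0 0
print(next(left), next(relabeled))   # 1 10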
def __init__(self,
             policy_network: networks_lib.FeedForwardNetwork,
             critic_network: networks_lib.FeedForwardNetwork,
             random_key: networks_lib.PRNGKey,
             discount: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             policy_optimizer: Optional[optax.GradientTransformation] = None,
             critic_optimizer: Optional[optax.GradientTransformation] = None,
             clipping: bool = True,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             jit: bool = True,
             num_sgd_steps_per_step: int = 1):

  def critic_mean(
      critic_params: networks_lib.Params,
      observation: types.NestedArray,
      action: types.NestedArray,
  ) -> jnp.ndarray:
    # We add a batch dimension to make sure batch concat in critic_network
    # works correctly.
    observation = utils.add_batch_dim(observation)
    action = utils.add_batch_dim(action)
    # Computes the mean action-value estimate.
    logits, atoms = critic_network.apply(critic_params, observation, action)
    logits = utils.squeeze_batch_dim(logits)
    probabilities = jax.nn.softmax(logits)
    return jnp.sum(probabilities * atoms, axis=-1)

  def policy_loss(
      policy_params: networks_lib.Params,
      critic_params: networks_lib.Params,
      o_t: types.NestedArray,
  ) -> jnp.ndarray:
    # Computes the deterministic policy gradient (DPG) loss.
    dpg_a_t = policy_network.apply(policy_params, o_t)
    grad_critic = jax.vmap(
        jax.grad(critic_mean, argnums=2), in_axes=(None, 0, 0))
    dq_da = grad_critic(critic_params, o_t, dpg_a_t)
    dqda_clipping = 1. if clipping else None
    batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0, None))
    loss = batch_dpg_learning(dpg_a_t, dq_da, dqda_clipping)
    return jnp.mean(loss)

  def critic_loss(
      critic_params: networks_lib.Params,
      state: TrainingState,
      transition: types.Transition,
  ):
    # Computes the distributional critic loss.
    q_tm1, atoms_tm1 = critic_network.apply(critic_params,
                                            transition.observation,
                                            transition.action)
    a = policy_network.apply(state.target_policy_params,
                             transition.next_observation)
    q_t, atoms_t = critic_network.apply(state.target_critic_params,
                                        transition.next_observation, a)
    batch_td_learning = jax.vmap(
        rlax.categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0))
    loss = batch_td_learning(atoms_tm1, q_tm1, transition.reward,
                             discount * transition.discount, atoms_t, q_t)
    return jnp.mean(loss)

  def sgd_step(
      state: TrainingState,
      transitions: types.Transition,
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    # TODO(jaslanides): Use a shared forward pass for efficiency.
    policy_loss_and_grad = jax.value_and_grad(policy_loss)
    critic_loss_and_grad = jax.value_and_grad(critic_loss)

    # Compute losses and their gradients.
    policy_loss_value, policy_gradients = policy_loss_and_grad(
        state.policy_params, state.critic_params,
        transitions.next_observation)
    critic_loss_value, critic_gradients = critic_loss_and_grad(
        state.critic_params, state, transitions)

    # Get optimizer updates and state.
    policy_updates, policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
        policy_gradients, state.policy_opt_state)
    critic_updates, critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
        critic_gradients, state.critic_opt_state)

    # Apply optimizer updates to parameters.
    policy_params = optax.apply_updates(state.policy_params, policy_updates)
    critic_params = optax.apply_updates(state.critic_params, critic_updates)

    steps = state.steps + 1

    # Periodically update target networks.
    target_policy_params, target_critic_params = optax.periodic_update(
        (policy_params, critic_params),
        (state.target_policy_params, state.target_critic_params), steps,
        self._target_update_period)

    new_state = TrainingState(
        policy_params=policy_params,
        critic_params=critic_params,
        target_policy_params=target_policy_params,
        target_critic_params=target_critic_params,
        policy_opt_state=policy_opt_state,
        critic_opt_state=critic_opt_state,
        steps=steps,
    )

    metrics = {
        'policy_loss': policy_loss_value,
        'critic_loss': critic_loss_value,
    }

    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  # Necessary to track when to update target networks.
  self._target_update_period = target_update_period

  # Create prefetching dataset iterator.
  self._iterator = iterator

  # Maybe use the JIT compiler.
  sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
  self._sgd_step = jax.jit(sgd_step) if jit else sgd_step

  # Create the network parameters and copy into the target network
  # parameters.
  key_policy, key_critic = jax.random.split(random_key)
  initial_policy_params = policy_network.init(key_policy)
  initial_critic_params = critic_network.init(key_critic)
  initial_target_policy_params = initial_policy_params
  initial_target_critic_params = initial_critic_params

  # Create optimizers if they aren't given.
  critic_optimizer = critic_optimizer or optax.adam(1e-4)
  policy_optimizer = policy_optimizer or optax.adam(1e-4)

  # Initialize optimizers.
  initial_policy_opt_state = policy_optimizer.init(initial_policy_params)  # pytype: disable=attribute-error
  initial_critic_opt_state = critic_optimizer.init(initial_critic_params)  # pytype: disable=attribute-error

  # Create initial state.
  self._state = TrainingState(
      policy_params=initial_policy_params,
      target_policy_params=initial_target_policy_params,
      critic_params=initial_critic_params,
      target_critic_params=initial_target_critic_params,
      policy_opt_state=initial_policy_opt_state,
      critic_opt_state=initial_critic_opt_state,
      steps=0,
  )

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
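# The distributional critic above represents Q(s, a) as a categorical
# distribution over fixed `atoms`; `critic_mean` collapses it to a scalar
# via the expectation sum(softmax(logits) * atoms). A small standalone
# illustration of that reduction with made-up logits and atom locations:
import jax.numpy as jnp
from jax import nn

atoms = jnp.linspace(-10.0, 10.0, 5)           # Support of the distribution.
logits = jnp.array([0.1, 0.2, 2.0, 0.2, 0.1])  # Unnormalized scores.

probabilities = nn.softmax(logits)
q_mean = jnp.sum(probabilities * atoms, axis=-1)
print(q_mean)  # Close to 0: mass is concentrated on the middle atom.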
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             random_key: networks_lib.PRNGKey,
             loss_fn: losses.Loss,
             optimizer: optax.GradientTransformation,
             prefetching_iterator: Iterator[types.Transition],
             num_sgd_steps_per_step: int,
             loss_has_aux: bool = False,
             logger: Optional[loggers.Logger] = None,
             counter: Optional[counting.Counter] = None):
  """Behavior Cloning Learner.

  Args:
    network: Networks with signature for apply:
      (params, obs, is_training, key) -> jnp.ndarray
      and for init: (rng, is_training) -> params
    random_key: RNG key.
    loss_fn: BC loss to use.
    optimizer: Optax optimizer.
    prefetching_iterator: A sharded prefetching iterator as outputted from
      `acme.jax.utils.sharded_prefetch`. Please see the documentation for
      `sharded_prefetch` for more details.
    num_sgd_steps_per_step: Number of gradient updates per step.
    loss_has_aux: Whether the loss function returns auxiliary metrics as a
      second argument.
    logger: Logger.
    counter: Counter.
  """

  def sgd_step(
      state: TrainingState,
      transitions: types.Transition,
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    loss_and_grad = jax.value_and_grad(
        loss_fn, argnums=1, has_aux=loss_has_aux)

    # Compute losses and their gradients.
    key, key_input = jax.random.split(state.key)
    loss_result, gradients = loss_and_grad(network.apply,
                                           state.policy_params, key_input,
                                           transitions)

    # Combine the gradient across all devices (by taking their mean).
    gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME)

    # Compute and combine metrics across all devices.
    metrics = _create_loss_metrics(loss_has_aux, loss_result, gradients)
    metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME)

    policy_update, optimizer_state = optimizer.update(gradients,
                                                      state.optimizer_state,
                                                      state.policy_params)
    policy_params = optax.apply_updates(state.policy_params, policy_update)

    new_state = TrainingState(
        optimizer_state=optimizer_state,
        policy_params=policy_params,
        key=key,
        steps=state.steps + 1,
    )

    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter(prefix='learner')
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  # Split the input batch to `num_sgd_steps_per_step` minibatches in order
  # to achieve better performance on accelerators.
  sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
  self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

  random_key, init_key = jax.random.split(random_key)
  policy_params = network.init(init_key)
  optimizer_state = optimizer.init(policy_params)

  # Create initial state.
  state = TrainingState(
      optimizer_state=optimizer_state,
      policy_params=policy_params,
      key=random_key,
      steps=0,
  )
  self._state = utils.replicate_in_all_devices(state)

  self._timestamp = None
  self._prefetching_iterator = prefetching_iterator
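# The BC learner above is data-parallel: the training state is replicated on
# every local device and each device sees its own shard of the batch, with
# gradients averaged via `jax.lax.pmean`. A minimal sketch of that pattern
# (the `replicate` helper and the toy loss are illustrative, not the
# library API):
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()

def replicate(tree):
  """Copies a pytree onto every local device (stacked leading axis)."""
  return jax.device_put_replicated(tree, jax.local_devices())

def step(params, batch):
  # Toy squared loss; gradients are averaged across devices.
  grads = jax.grad(lambda p: jnp.mean((p * batch - 1.0) ** 2))(params)
  grads = jax.lax.pmean(grads, axis_name='devices')
  return params - 0.1 * grads

p_step = jax.pmap(step, axis_name='devices')
params = replicate(jnp.float32(0.0))
# Shard the batch: the leading axis must equal the number of devices.
batch = jnp.arange(num_devices * 4, dtype=jnp.float32).reshape(num_devices, 4)
params = p_step(params, batch)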
  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray)

  # Iterator on demonstration transitions.
  self._iterator = iterator

  # update_step = utils.process_multiple_batches(update_step,
  #                                              num_sgd_steps_per_step)
  # # Use the JIT compiler.
  # self._update_step = jax.jit(update_step)
  self._update_step_in_initial_bc_iters = jax.jit(
      utils.process_multiple_batches(
          lambda x, y: _full_update_step(x, y, True),
          num_sgd_steps_per_step))
  self._update_step_rest = jax.jit(
      utils.process_multiple_batches(
          lambda x, y: _full_update_step(x, y, False),
          num_sgd_steps_per_step))

  def make_initial_state(key):
    """Initialises the training state (parameters and optimiser state)."""
    key_policy, key_q, key = jax.random.split(key, 3)
    policy_params = networks.policy_network.init(key_policy)
    policy_optimizer_state = policy_optimizer.init(policy_params)
    q_params = networks.q_network.init(key_q)
    q_optimizer_state = q_optimizer.init(q_params)
def __init__(self,
             networks: sac_networks.SACNetworks,
             rng: jnp.ndarray,
             iterator: Iterator[reverb.ReplaySample],
             policy_optimizer: optax.GradientTransformation,
             q_optimizer: optax.GradientTransformation,
             tau: float = 0.005,
             reward_scale: float = 1.0,
             discount: float = 0.99,
             entropy_coefficient: Optional[float] = None,
             target_entropy: float = 0,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             num_sgd_steps_per_step: int = 1):
  """Initialize the SAC learner.

  Args:
    networks: SAC networks.
    rng: a key for random number generation.
    iterator: an iterator over training data.
    policy_optimizer: the policy optimizer.
    q_optimizer: the Q-function optimizer.
    tau: target smoothing coefficient.
    reward_scale: reward scale.
    discount: discount to use for TD updates.
    entropy_coefficient: coefficient applied to the entropy bonus. If None,
      an adaptive coefficient will be used.
    target_entropy: Used to normalize entropy. Only used when
      entropy_coefficient is None.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    num_sgd_steps_per_step: number of sgd steps to perform per learner
      'step'.
  """
  adaptive_entropy_coefficient = entropy_coefficient is None
  if adaptive_entropy_coefficient:
    # alpha is the temperature parameter that determines the relative
    # importance of the entropy term versus the reward.
    log_alpha = jnp.asarray(0., dtype=jnp.float32)
    alpha_optimizer = optax.adam(learning_rate=3e-4)
    alpha_optimizer_state = alpha_optimizer.init(log_alpha)
  else:
    if target_entropy:
      raise ValueError('target_entropy should not be set when '
                       'entropy_coefficient is provided')

  def alpha_loss(log_alpha: jnp.ndarray,
                 policy_params: networks_lib.Params,
                 transitions: types.Transition,
                 key: networks_lib.PRNGKey) -> jnp.ndarray:
    """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
    dist_params = networks.policy_network.apply(policy_params,
                                                transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    alpha = jnp.exp(log_alpha)
    alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy)
    return jnp.mean(alpha_loss)

  def critic_loss(q_params: networks_lib.Params,
                  policy_params: networks_lib.Params,
                  target_q_params: networks_lib.Params,
                  alpha: jnp.ndarray,
                  transitions: types.Transition,
                  key: networks_lib.PRNGKey) -> jnp.ndarray:
    q_old_action = networks.q_network.apply(q_params,
                                            transitions.observation,
                                            transitions.action)
    next_dist_params = networks.policy_network.apply(
        policy_params, transitions.next_observation)
    next_action = networks.sample(next_dist_params, key)
    next_log_prob = networks.log_prob(next_dist_params, next_action)
    next_q = networks.q_network.apply(target_q_params,
                                      transitions.next_observation,
                                      next_action)
    next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
    target_q = jax.lax.stop_gradient(
        transitions.reward * reward_scale +
        transitions.discount * discount * next_v)
    q_error = q_old_action - jnp.expand_dims(target_q, -1)
    q_loss = 0.5 * jnp.mean(jnp.square(q_error))
    return q_loss

  def actor_loss(policy_params: networks_lib.Params,
                 q_params: networks_lib.Params,
                 alpha: jnp.ndarray,
                 transitions: types.Transition,
                 key: networks_lib.PRNGKey) -> jnp.ndarray:
    dist_params = networks.policy_network.apply(policy_params,
                                                transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    q_action = networks.q_network.apply(q_params, transitions.observation,
                                        action)
    min_q = jnp.min(q_action, axis=-1)
    actor_loss = alpha * log_prob - min_q
    return jnp.mean(actor_loss)

  alpha_grad = jax.value_and_grad(alpha_loss)
  critic_grad = jax.value_and_grad(critic_loss)
  actor_grad = jax.value_and_grad(actor_loss)

  def update_step(
      state: TrainingState,
      transitions: types.Transition,
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)

    if adaptive_entropy_coefficient:
      alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                           state.policy_params, transitions,
                                           key_alpha)
      alpha = jnp.exp(state.alpha_params)
    else:
      alpha = entropy_coefficient

    critic_loss, critic_grads = critic_grad(state.q_params,
                                            state.policy_params,
                                            state.target_q_params, alpha,
                                            transitions, key_critic)
    actor_loss, actor_grads = actor_grad(state.policy_params,
                                         state.q_params, alpha, transitions,
                                         key_actor)

    # Apply policy gradients.
    actor_update, policy_optimizer_state = policy_optimizer.update(
        actor_grads, state.policy_optimizer_state)
    policy_params = optax.apply_updates(state.policy_params, actor_update)

    # Apply critic gradients.
    critic_update, q_optimizer_state = q_optimizer.update(
        critic_grads, state.q_optimizer_state)
    q_params = optax.apply_updates(state.q_params, critic_update)

    new_target_q_params = jax.tree_map(
        lambda x, y: x * (1 - tau) + y * tau, state.target_q_params,
        q_params)

    metrics = {
        'critic_loss': critic_loss,
        'actor_loss': actor_loss,
    }

    new_state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        q_optimizer_state=q_optimizer_state,
        policy_params=policy_params,
        q_params=q_params,
        target_q_params=new_target_q_params,
        key=key,
    )
    if adaptive_entropy_coefficient:
      # Apply alpha gradients.
      alpha_update, alpha_optimizer_state = alpha_optimizer.update(
          alpha_grads, state.alpha_optimizer_state)
      alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
      metrics.update({
          'alpha_loss': alpha_loss,
          'alpha': jnp.exp(alpha_params),
      })
      new_state = new_state._replace(
          alpha_optimizer_state=alpha_optimizer_state,
          alpha_params=alpha_params)

    metrics['rewards_mean'] = jnp.mean(
        jnp.abs(jnp.mean(transitions.reward, axis=0)))
    metrics['rewards_std'] = jnp.std(transitions.reward, axis=0)

    return new_state, metrics

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  # Iterator on demonstration transitions.
  self._iterator = iterator

  update_step = utils.process_multiple_batches(update_step,
                                               num_sgd_steps_per_step)
  # Use the JIT compiler.
  self._update_step = jax.jit(update_step)

  def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    key_policy, key_q, key = jax.random.split(key, 3)
    policy_params = networks.policy_network.init(key_policy)
    policy_optimizer_state = policy_optimizer.init(policy_params)
    q_params = networks.q_network.init(key_q)
    q_optimizer_state = q_optimizer.init(q_params)
    state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        q_optimizer_state=q_optimizer_state,
        policy_params=policy_params,
        q_params=q_params,
        target_q_params=q_params,
        key=key)
    if adaptive_entropy_coefficient:
      state = state._replace(
          alpha_optimizer_state=alpha_optimizer_state,
          alpha_params=log_alpha)
    return state

  # Create initial state.
  self._state = make_initial_state(rng)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
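# Both SAC-style learners above maintain target networks with a Polyak
# (exponential moving average) update, target <- (1 - tau) * target +
# tau * online, applied leaf-by-leaf over the parameter pytree. A minimal
# standalone sketch of that update:
import jax
import jax.numpy as jnp

def polyak_update(target_params, online_params, tau: float = 0.005):
  """Returns the EMA of `online_params` into `target_params`."""
  return jax.tree_util.tree_map(
      lambda t, o: (1 - tau) * t + tau * o, target_params, online_params)

target = {'w': jnp.zeros(3), 'b': jnp.zeros(1)}
online = {'w': jnp.ones(3), 'b': jnp.ones(1)}
target = polyak_update(target, online, tau=0.1)  # w -> [0.1, 0.1, 0.1]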
def __init__(
    self,
    networks,
    rng,
    iterator,
    policy_lr=1e-4,
    loss_type='MLE',  # or 'MSE'
    regularize_entropy=False,
    entropy_regularization_weight=1.0,
    use_img_encoder=False,
    img_encoder_params_ckpt_path='',
    counter=None,
    logger=None,
    num_sgd_steps_per_step=1):
  """Initialize the BC learner.

  Args:
    networks: BC networks.
    rng: a key for random number generation.
    iterator: an iterator over training data.
    policy_lr: learning rate for the policy.
    loss_type: the BC loss to use, either 'MLE' or 'MSE'.
    regularize_entropy: whether to regularize the entropy of the policy.
    entropy_regularization_weight: weight for entropy regularization.
    use_img_encoder: whether to preprocess the image part of the observation
      using a pretrained encoder.
    img_encoder_params_ckpt_path: path to checkpoint for image encoder
      params.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    num_sgd_steps_per_step: number of sgd steps to perform per learner
      'step'.
  """
  assert loss_type in ['MLE', 'MSE'], 'Invalid BC loss type!'
  num_devices = len(jax.devices())
  self._num_sgd_steps_per_step = num_sgd_steps_per_step
  self._use_img_encoder = use_img_encoder
  policy_optimizer = optax.adam(learning_rate=policy_lr)

  def actor_loss(policy_params, transitions, key, img_encoder_params):
    obs = transitions.observation
    acts = transitions.action
    if use_img_encoder:
      img = obs['state_image']
      dense = obs['state_dense']
      obs = dict(
          state_image=networks.img_encoder.apply(img_encoder_params, img),
          state_dense=dense)

    dist = networks.policy_network.apply(policy_params, obs)
    if loss_type == 'MLE':
      log_probs = networks.log_prob(dist, acts)
      loss = -1. * jnp.mean(log_probs)
    else:
      acts_mode = dist.mode()
      mse = jnp.sum((acts_mode - acts)**2, axis=-1)
      loss = 0.5 * jnp.mean(mse)
    total_loss = loss

    entropy_term = 0.
    if regularize_entropy:
      sample_acts = networks.sample(dist, key)
      sample_log_probs = networks.log_prob(dist, sample_acts)
      entropy_term = jnp.mean(sample_log_probs)
      total_loss = total_loss + entropy_regularization_weight * entropy_term

    return total_loss, (loss, entropy_term)

  actor_loss_and_grad = jax.value_and_grad(actor_loss, has_aux=True)

  def actor_update_step(policy_params, optim_state, transitions, key,
                        img_encoder_params):
    (total_loss,
     (bc_loss_term, entropy_term)), actor_grad = actor_loss_and_grad(
         policy_params, transitions, key, img_encoder_params)
    actor_grad = jax.lax.pmean(actor_grad, 'across_devices')
    policy_update, optim_state = policy_optimizer.update(
        actor_grad, optim_state)
    policy_params = optax.apply_updates(policy_params, policy_update)
    return policy_params, optim_state, total_loss, bc_loss_term, entropy_term

  pmapped_actor_update_step = jax.pmap(
      actor_update_step, axis_name='across_devices', in_axes=0, out_axes=0)

  def _full_update_step(
      state,
      transitions,
  ):
    """The unjitted version of the full update step."""
    metrics = OrderedDict()
    key = state.key

    # Actor update step.
    def reshape_for_devices(t):
      rest_t_shape = list(t.shape[1:])
      new_shape = [num_devices, t.shape[0] // num_devices] + rest_t_shape
      return jnp.reshape(t, new_shape)

    transitions = jax.tree_map(reshape_for_devices, transitions)
    sub_keys = jax.random.split(key, num_devices + 1)
    key = sub_keys[0]
    sub_keys = sub_keys[1:]
    (new_policy_params, new_policy_optimizer_state, total_loss, bc_loss_term,
     entropy_term) = pmapped_actor_update_step(state.policy_params,
                                               state.policy_optimizer_state,
                                               transitions, sub_keys,
                                               state.img_encoder_params)
    metrics['total_actor_loss'] = jnp.mean(total_loss)
    metrics['BC_loss'] = jnp.mean(bc_loss_term)
    metrics['entropy_loss'] = jnp.mean(entropy_term)
      # Create new state.
      new_state = TrainingState(
          policy_optimizer_state=new_policy_optimizer_state,
          policy_params=new_policy_params,
          key=key,
          img_encoder_params=state.img_encoder_params)

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._iterator = iterator

    self._update_step = utils.process_multiple_batches(
        _full_update_step, num_sgd_steps_per_step)

    def make_initial_state(key):
      """Initialises the training state (parameters and optimiser state)."""
      # Initialise and replicate the policy parameters and optimiser state.
      key, sub_key = jax.random.split(key)
      policy_params = networks.policy_network.init(sub_key)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      devices = jax.local_devices()
      replicated_policy_params = jax.device_put_replicated(
          policy_params, devices)
      replicated_optim_state = jax.device_put_replicated(
          policy_optimizer_state, devices)

      if use_img_encoder:
        # Load pretrained img_encoder_params and do:
        #   replicated_img_encoder_params = jax.device_put_replicated(
        #       img_encoder_params, devices)
        class EncoderTrainingState(NamedTuple):
          encoder_params: hk.Params
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params
        raise NotImplementedError('Need to load a checkpoint.')
      else:
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params

      state = TrainingState(
          policy_optimizer_state=replicated_optim_state,
          policy_params=replicated_policy_params,
          key=key,
          img_encoder_params=replicated_img_encoder_params)
      return state

    # Create initial state.
    self._state = make_initial_state(rng)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None
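# A minimal runnable sketch, separate from the learner above, of the
# shard-then-pmap pattern that `_full_update_step` relies on. Everything here
# (function names, toy shapes, the dummy per-device computation) is
# illustrative only: reshape the batch so its leading axis matches the device
# count, split one PRNG key per device, and reduce across devices with
# `jax.lax.pmean` under the same `axis_name` given to `jax.pmap`.
def _pmap_pattern_demo():
  import jax
  import jax.numpy as jnp

  num_devices = jax.local_device_count()

  def per_device_step(x, key):
    # A stand-in for a gradient computation: each device reduces its shard,
    # then results are averaged across devices (as pmean does for gradients).
    local_value = jnp.mean(x) + 0. * jax.random.normal(key, ())
    return jax.lax.pmean(local_value, axis_name='across_devices')

  pmapped_step = jax.pmap(per_device_step, axis_name='across_devices')

  batch = jnp.arange(8 * num_devices, dtype=jnp.float32)
  sharded = jnp.reshape(batch, (num_devices, -1))  # [num_devices, B/devices]
  keys = jax.random.split(jax.random.PRNGKey(0), num_devices)
  return pmapped_step(sharded, keys)  # The same value on every device.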
  def __init__(self,
               network,
               random_key,
               temperature,
               num_actions,
               optimizer,
               demonstrations,
               num_sgd_steps_per_step,
               logger=None,
               counter=None):

    def aqualoss(params, transitions, key):
      predicted_actions = network.apply(
          params,
          transitions.observation,
          is_training=True,
          rngs={'dropout': key})
      predicted_actions = jnp.squeeze(predicted_actions)
      action_distances = jnp.sum(
          jnp.square(predicted_actions -
                     jnp.expand_dims(transitions.action, axis=-1)),
          axis=0)
      # Softmin over the candidate actions.
      softmin_action_distances = temperature * (
          jax.nn.logsumexp(-action_distances / temperature) -
          jnp.log(num_actions))
      loss = -softmin_action_distances
      return loss

    def batch_aqualoss(params, transitions, key):
      batched_aqualoss = jax.vmap(
          aqualoss, in_axes=(None, 0, None), out_axes=0)
      return jnp.mean(batched_aqualoss(params, transitions, key))

    def sgd_step(
        state,
        transitions,
    ):
      loss_and_grad = jax.value_and_grad(batch_aqualoss, argnums=0)

      # Compute losses and their gradients.
      loss_key, random_key = jax.random.split(state.random_key)
      loss_value, gradients = loss_and_grad(state.encoder_params, transitions,
                                            loss_key)

      update, optimizer_state = optimizer.update(
          gradients, state.optimizer_state, params=state.encoder_params)
      encoder_params = optax.apply_updates(state.encoder_params, update)

      new_state = PretrainingState(
          optimizer_state=optimizer_state,
          encoder_params=encoder_params,
          random_key=random_key,
          steps=state.steps + 1,
      )
      metrics = {
          'encoder_loss': loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='encoder')
    self._logger = logger or loggers.make_default_logger(
        'encoder',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # Use the JIT compiler.
    self._sgd_step = utils.process_multiple_batches(sgd_step,
                                                    num_sgd_steps_per_step)
    self._sgd_step = jax.jit(self._sgd_step)

    self._num_actions = num_actions

    encoder_params = network.init(random_key)
    optimizer_state = optimizer.init(encoder_params)

    # Create initial state.
    self._state = PretrainingState(
        optimizer_state=optimizer_state,
        encoder_params=encoder_params,
        random_key=random_key,
        steps=0,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None
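# A standalone numeric sketch of the softmin used in `aqualoss` above (the toy
# distances are made up). Since logsumexp(-d / T) -> -min(d) / T as T -> 0,
# the loss -T * (logsumexp(-d / T) - log(N)) approaches min(d) + T * log(N):
# a smooth, differentiable surrogate for the distance to the closest
# candidate action.
def _softmin_demo():
  import jax
  import jax.numpy as jnp

  distances = jnp.array([4.0, 1.0, 3.0])
  num_actions = distances.shape[0]
  temperature = 1e-2

  softmin_distances = temperature * (
      jax.nn.logsumexp(-distances / temperature) - jnp.log(num_actions))
  loss = -softmin_distances
  # For small temperatures `loss` is close to jnp.min(distances) == 1.0, so
  # minimizing it pulls the nearest candidate action towards the target.
  return loss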
  def __init__(self,
               networks: CRRNetworks,
               random_key: networks_lib.PRNGKey,
               discount: float,
               target_update_period: int,
               policy_loss_coeff_fn: PolicyLossCoeff,
               iterator: Iterator[types.Transition],
               policy_optimizer: optax.GradientTransformation,
               critic_optimizer: optax.GradientTransformation,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None,
               grad_updates_per_batch: int = 1,
               use_sarsa_target: bool = False):
    """Initializes the CRR learner.

    Args:
      networks: CRR networks.
      random_key: a key for random number generation.
      discount: discount to use for TD updates.
      target_update_period: period to update target's parameters.
      policy_loss_coeff_fn: computes the per-example coefficients used to
        weight the policy loss.
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      grad_updates_per_batch: how many gradient updates given a sampled batch.
      use_sarsa_target: compute on-policy target using iterator's actions
        rather than sampled actions.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused.
    """
    critic_network = networks.critic_network
    policy_network = networks.policy_network

    def policy_loss(
        policy_params: networks_lib.Params,
        critic_params: networks_lib.Params,
        transition: types.Transition,
        key: networks_lib.PRNGKey,
    ) -> jnp.ndarray:
      # Compute the loss coefficients.
      coeff = policy_loss_coeff_fn(networks, policy_params, critic_params,
                                   transition, key)
      coeff = jax.lax.stop_gradient(coeff)
      # Return the weighted loss.
      dist_params = policy_network.apply(policy_params, transition.observation)
      logp_action = networks.log_prob(dist_params, transition.action)
      return -jnp.mean(logp_action * coeff)

    def critic_loss(
        critic_params: networks_lib.Params,
        target_policy_params: networks_lib.Params,
        target_critic_params: networks_lib.Params,
        transition: types.Transition,
        key: networks_lib.PRNGKey,
    ):
      # Sample the next action.
      if use_sarsa_target:
        # TODO(b/222674779): use N-steps Trajectories to get the next actions.
        assert 'next_action' in transition.extras, (
            'next actions should be given as extras for one step RL.')
        next_action = transition.extras['next_action']
      else:
        next_dist_params = policy_network.apply(target_policy_params,
                                                transition.next_observation)
        next_action = networks.sample(next_dist_params, key)
      # Calculate the value of the next state and action.
      next_q = critic_network.apply(target_critic_params,
                                    transition.next_observation, next_action)
      target_q = transition.reward + transition.discount * discount * next_q
      target_q = jax.lax.stop_gradient(target_q)

      q = critic_network.apply(critic_params, transition.observation,
                               transition.action)
      q_error = q - target_q
      # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
      # TODO(sertan): Replace with a distributional critic. CRR paper states
      # that this may perform better.
      return 0.5 * jnp.mean(jnp.square(q_error))

    policy_loss_and_grad = jax.value_and_grad(policy_loss)
    critic_loss_and_grad = jax.value_and_grad(critic_loss)

    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      key, key_policy, key_critic = jax.random.split(state.key, 3)

      # Compute losses and their gradients.
      policy_loss_value, policy_gradients = policy_loss_and_grad(
          state.policy_params, state.critic_params, transitions, key_policy)
      critic_loss_value, critic_gradients = critic_loss_and_grad(
          state.critic_params, state.target_policy_params,
          state.target_critic_params, transitions, key_critic)

      # Get optimizer updates and state.
      policy_updates, policy_opt_state = policy_optimizer.update(
          policy_gradients, state.policy_opt_state)
      critic_updates, critic_opt_state = critic_optimizer.update(
          critic_gradients, state.critic_opt_state)

      # Apply optimizer updates to parameters.
      policy_params = optax.apply_updates(state.policy_params, policy_updates)
      critic_params = optax.apply_updates(state.critic_params, critic_updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_policy_params, target_critic_params = optax.periodic_update(
          (policy_params, critic_params),
          (state.target_policy_params, state.target_critic_params), steps,
          target_update_period)

      new_state = TrainingState(
          policy_params=policy_params,
          target_policy_params=target_policy_params,
          critic_params=critic_params,
          target_critic_params=target_critic_params,
          policy_opt_state=policy_opt_state,
          critic_opt_state=critic_opt_state,
          steps=steps,
          key=key,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'critic_loss': critic_loss_value,
      }

      return new_state, metrics

    sgd_step = utils.process_multiple_batches(sgd_step, grad_updates_per_batch)
    self._sgd_step = jax.jit(sgd_step)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Create prefetching dataset iterator.
    self._iterator = iterator

    # Create the network parameters and copy into the target network
    # parameters.
    key, key_policy, key_critic = jax.random.split(random_key, 3)
    initial_policy_params = policy_network.init(key_policy)
    initial_critic_params = critic_network.init(key_critic)
    initial_target_policy_params = initial_policy_params
    initial_target_critic_params = initial_critic_params

    # Initialize optimizers.
    initial_policy_opt_state = policy_optimizer.init(initial_policy_params)
    initial_critic_opt_state = critic_optimizer.init(initial_critic_params)

    # Create initial state.
    self._state = TrainingState(
        policy_params=initial_policy_params,
        target_policy_params=initial_target_policy_params,
        critic_params=initial_critic_params,
        target_critic_params=initial_target_critic_params,
        policy_opt_state=initial_policy_opt_state,
        critic_opt_state=initial_critic_opt_state,
        steps=0,
        key=key,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None
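# A standalone sketch of one plausible `policy_loss_coeff_fn`, matching the
# call site in `policy_loss` above. The estimator below is an assumption in
# the spirit of CRR's binary-indicator variant, not the library's
# implementation: weight each log-prob by 1[Q(s, a) > V(s)], estimating V(s)
# by averaging the critic over a few actions sampled from the current policy.
def _indicator_policy_loss_coeff(networks, policy_params, critic_params,
                                 transition, key, num_action_samples=4):
  import jax
  import jax.numpy as jnp

  dist_params = networks.policy_network.apply(policy_params,
                                              transition.observation)
  sample_keys = jax.random.split(key, num_action_samples)

  def sampled_q(k):
    action = networks.sample(dist_params, k)
    return networks.critic_network.apply(critic_params,
                                         transition.observation, action)

  # Monte-Carlo estimate of the state value under the current policy.
  value = jnp.mean(jnp.stack([sampled_q(k) for k in sample_keys]), axis=0)
  q = networks.critic_network.apply(critic_params, transition.observation,
                                    transition.action)
  # Binary coefficient: only imitate actions the critic prefers to the policy.
  return (q - value > 0.).astype(jnp.float32)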
  def __init__(self,
               network: networks_lib.FeedForwardNetwork,
               random_key: networks_lib.PRNGKey,
               loss_fn: losses.Loss,
               optimizer: optax.GradientTransformation,
               demonstrations: Iterator[types.Transition],
               num_sgd_steps_per_step: int,
               logger: Optional[loggers.Logger] = None,
               counter: Optional[counting.Counter] = None):

    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      loss_and_grad = jax.value_and_grad(loss_fn, argnums=1)

      # Compute losses and their gradients.
      key, key_input = jax.random.split(state.key)
      loss_value, gradients = loss_and_grad(network.apply, state.policy_params,
                                            key_input, transitions)

      policy_update, optimizer_state = optimizer.update(gradients,
                                                        state.optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      new_state = TrainingState(
          optimizer_state=optimizer_state,
          policy_params=policy_params,
          key=key,
          steps=state.steps + 1,
      )
      metrics = {
          'loss': loss_value,
          'gradient_norm': optax.global_norm(gradients)
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='learner')
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # Split the input batch into `num_sgd_steps_per_step` minibatches in order
    # to achieve better performance on accelerators.
    self._sgd_step = jax.jit(utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step))

    random_key, init_key = jax.random.split(random_key)
    policy_params = network.init(init_key)
    optimizer_state = optimizer.init(policy_params)

    # Create initial state.
    self._state = TrainingState(
        optimizer_state=optimizer_state,
        policy_params=policy_params,
        key=random_key,
        steps=0,
    )

    self._timestamp = None
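# A standalone sketch of a loss compatible with the `loss_fn` calling
# convention above, where `sgd_step` differentiates w.r.t. argument 1 and
# calls `loss_fn(network.apply, params, key, transitions)`. The assumption
# that `apply_fn` returns a distribution exposing a `log_prob` method is
# illustrative, not the library's required interface.
def _mle_bc_loss(apply_fn, params, key, transitions):
  import jax.numpy as jnp
  del key  # This particular loss draws no samples.
  dist = apply_fn(params, transitions.observation)
  return -jnp.mean(dist.log_prob(transitions.action))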
  def __init__(self,
               networks: td3_networks.TD3Networks,
               random_key: networks_lib.PRNGKey,
               discount: float,
               iterator: Iterator[reverb.ReplaySample],
               policy_optimizer: optax.GradientTransformation,
               critic_optimizer: optax.GradientTransformation,
               twin_critic_optimizer: optax.GradientTransformation,
               delay: int = 2,
               target_sigma: float = 0.2,
               noise_clip: float = 0.5,
               tau: float = 0.005,
               use_sarsa_target: bool = False,
               bc_alpha: Optional[float] = None,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None,
               num_sgd_steps_per_step: int = 1):
    """Initializes the TD3 learner.

    Args:
      networks: TD3 networks.
      random_key: a key for random number generation.
      discount: discount to use for TD updates.
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      twin_critic_optimizer: the twin Q-function optimizer.
      delay: ratio of critic updates per policy update (see TD3); delay=2
        means two updates of the critic for every policy update.
      target_sigma: std of the zero-mean Gaussian noise added to the
        next-state action for critic evaluation (reducing overestimation
        bias).
      noise_clip: hard constraint on the target noise.
      tau: target parameters smoothing coefficient.
      use_sarsa_target: compute on-policy target using iterator's actions
        rather than sampled actions.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused. This only
        works when the learner is used as an offline algorithm, i.e.
        TD3Builder does not support adding the SARSA target to the replay
        buffer.
      bc_alpha: Implements TD3+BC. See comments in TD3Config.bc_alpha for
        details.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      num_sgd_steps_per_step: number of sgd steps to perform per learner
        'step'.
    """

    def policy_loss(
        policy_params: networks_lib.Params,
        critic_params: networks_lib.Params,
        transition: types.NestedArray,
    ) -> jnp.ndarray:
      # Computes the deterministic policy gradient (DPG) loss.
      action = networks.policy_network.apply(policy_params,
                                             transition.observation)
      grad_critic = jax.vmap(
          jax.grad(networks.critic_network.apply, argnums=2),
          in_axes=(None, 0, 0))
      dq_da = grad_critic(critic_params, transition.observation, action)
      batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0))
      loss = jnp.mean(batch_dpg_learning(action, dq_da))
      if bc_alpha is not None:
        # BC regularization for offline RL.
        q_sa = networks.critic_network.apply(critic_params,
                                             transition.observation, action)
        bc_factor = jax.lax.stop_gradient(bc_alpha / jnp.mean(jnp.abs(q_sa)))
        loss += jnp.mean(jnp.square(action - transition.action)) / bc_factor
      return loss

    def critic_loss(
        critic_params: networks_lib.Params,
        state: TrainingState,
        transition: types.Transition,
        random_key: jnp.ndarray,
    ):
      # Computes the critic loss.
      q_tm1 = networks.critic_network.apply(critic_params,
                                            transition.observation,
                                            transition.action)

      if use_sarsa_target:
        # TODO(b/222674779): use N-steps Trajectories to get the next actions.
        assert 'next_action' in transition.extras, (
            'next actions should be given as extras for one step RL.')
        action = transition.extras['next_action']
      else:
        action = networks.policy_network.apply(state.target_policy_params,
                                               transition.next_observation)
        action = networks.add_policy_noise(action, random_key, target_sigma,
                                           noise_clip)

      q_t = networks.critic_network.apply(state.target_critic_params,
                                          transition.next_observation, action)
      twin_q_t = networks.twin_critic_network.apply(
          state.target_twin_critic_params, transition.next_observation, action)

      q_t = jnp.minimum(q_t, twin_q_t)

      target_q_tm1 = transition.reward + discount * transition.discount * q_t
      td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1

      return jnp.mean(jnp.square(td_error))

    def update_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      random_key, key_critic, key_twin = jax.random.split(state.random_key, 3)

      # Updates on the critic: compute the gradients, and update using
      # Polyak averaging.
      critic_loss_and_grad = jax.value_and_grad(critic_loss)
      critic_loss_value, critic_gradients = critic_loss_and_grad(
          state.critic_params, state, transitions, key_critic)
      critic_updates, critic_opt_state = critic_optimizer.update(
          critic_gradients, state.critic_opt_state)
      critic_params = optax.apply_updates(state.critic_params, critic_updates)
      # In the original authors' implementation the critic target update is
      # delayed similarly to the policy update, which we found empirically to
      # perform slightly worse.
      target_critic_params = optax.incremental_update(
          new_tensors=critic_params,
          old_tensors=state.target_critic_params,
          step_size=tau)

      # Updates on the twin critic: compute the gradients, and update using
      # Polyak averaging.
      twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad(
          state.twin_critic_params, state, transitions, key_twin)
      twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update(
          twin_critic_gradients, state.twin_critic_opt_state)
      twin_critic_params = optax.apply_updates(state.twin_critic_params,
                                               twin_critic_updates)
      # In the original authors' implementation the twin critic target update
      # is delayed similarly to the policy update, which we found empirically
      # to perform slightly worse.
      target_twin_critic_params = optax.incremental_update(
          new_tensors=twin_critic_params,
          old_tensors=state.target_twin_critic_params,
          step_size=tau)

      # Updates on the policy: compute the gradients, and update using
      # Polyak averaging (if delay enabled, the update might not be applied).
      policy_loss_and_grad = jax.value_and_grad(policy_loss)
      policy_loss_value, policy_gradients = policy_loss_and_grad(
          state.policy_params, state.critic_params, transitions)

      def update_policy_step():
        policy_updates, policy_opt_state = policy_optimizer.update(
            policy_gradients, state.policy_opt_state)
        policy_params = optax.apply_updates(state.policy_params,
                                            policy_updates)
        target_policy_params = optax.incremental_update(
            new_tensors=policy_params,
            old_tensors=state.target_policy_params,
            step_size=tau)
        return policy_params, target_policy_params, policy_opt_state

      # The update on the policy is applied every `delay` steps.
      current_policy_state = (state.policy_params, state.target_policy_params,
                              state.policy_opt_state)
      policy_params, target_policy_params, policy_opt_state = jax.lax.cond(
          state.steps % delay == 0,
          lambda _: update_policy_step(),
          lambda _: current_policy_state,
          operand=None)

      steps = state.steps + 1

      new_state = TrainingState(
          policy_params=policy_params,
          critic_params=critic_params,
          twin_critic_params=twin_critic_params,
          target_policy_params=target_policy_params,
          target_critic_params=target_critic_params,
          target_twin_critic_params=target_twin_critic_params,
          policy_opt_state=policy_opt_state,
          critic_opt_state=critic_opt_state,
          twin_critic_opt_state=twin_critic_opt_state,
          steps=steps,
          random_key=random_key,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'critic_loss': critic_loss_value,
          'twin_critic_loss': twin_critic_loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Create prefetching dataset iterator.
    self._iterator = iterator

    # Faster sgd step.
    update_step = utils.process_multiple_batches(update_step,
                                                 num_sgd_steps_per_step)
    # Use the JIT compiler.
    self._update_step = jax.jit(update_step)

    (key_init_policy, key_init_critic, key_init_twin_critic,
     key_state) = jax.random.split(random_key, 4)
    # Create the network parameters and copy into the target network
    # parameters.
    initial_policy_params = networks.policy_network.init(key_init_policy)
    initial_critic_params = networks.critic_network.init(key_init_critic)
    initial_twin_critic_params = networks.twin_critic_network.init(
        key_init_twin_critic)
    initial_target_policy_params = initial_policy_params
    initial_target_critic_params = initial_critic_params
    initial_target_twin_critic_params = initial_twin_critic_params

    # Initialize optimizers.
    initial_policy_opt_state = policy_optimizer.init(initial_policy_params)
    initial_critic_opt_state = critic_optimizer.init(initial_critic_params)
    initial_twin_critic_opt_state = twin_critic_optimizer.init(
        initial_twin_critic_params)

    # Create initial state.
    self._state = TrainingState(
        policy_params=initial_policy_params,
        target_policy_params=initial_target_policy_params,
        critic_params=initial_critic_params,
        twin_critic_params=initial_twin_critic_params,
        target_critic_params=initial_target_critic_params,
        target_twin_critic_params=initial_target_twin_critic_params,
        policy_opt_state=initial_policy_opt_state,
        critic_opt_state=initial_critic_opt_state,
        twin_critic_opt_state=initial_twin_critic_opt_state,
        steps=0,
        random_key=key_state)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None
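# A standalone sketch of target policy smoothing in the spirit of
# `networks.add_policy_noise` above; the exact clipping bounds are an
# assumption (TD3 clips to the action spec), not the library's code. Clipped
# zero-mean Gaussian noise is added to the target action before evaluating
# the target critics, which regularizes the value estimate.
def _add_policy_noise_sketch(action, key, target_sigma, noise_clip):
  import jax
  import jax.numpy as jnp
  noise = target_sigma * jax.random.normal(key, shape=action.shape)
  noise = jnp.clip(noise, -noise_clip, noise_clip)
  # Assumes actions live in [-1, 1]; a real implementation clips to the spec.
  return jnp.clip(action + noise, -1.0, 1.0)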