from functools import partial

from ray.rllib.policy.policy import Policy
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, TensorType

tf1, tf, tfv = try_import_tf()


def gradients_fn(policy: Policy, optimizer: LocalOptimizer,
                 loss: TensorType) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Note: For DDPG, `optimizer` and `loss` are ignored because we have two
    losses (actor and critic) and two local optimizers (both stored on the
    policy). In the tf-eager case, `optimizer` is still used: it is then a
    fake optimizer (OptimizerWrapper) whose `tape` property holds the
    GradientTape used for gradient recording.

    Args:
        policy (Policy): The Policy object that generated the loss tensor and
            that holds the two local optimizers.
        optimizer (LocalOptimizer): The tf (local) optimizer object to
            calculate the gradients with.
        loss (TensorType): The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the possibly clipped gradient- and variable
            tuples.
    """
    # Eager: Use the GradientTape stored on the (fake) `optimizer`.
    if policy.config["framework"] in ["tf2", "tfe"]:
        tape = optimizer.tape
        pol_weights = policy.model.policy_variables()
        actor_grads_and_vars = list(
            zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights))
        q_weights = policy.model.q_variables()
        critic_grads_and_vars = list(
            zip(tape.gradient(policy.critic_loss, q_weights), q_weights))
    # Tf1.x: Use `optimizer.compute_gradients()`.
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables())
        critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
            policy.critic_loss, var_list=policy.model.q_variables())

    # Clip if necessary.
    if policy.config["grad_clip"]:
        clip_func = partial(
            tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
    else:
        clip_func = tf.identity

    # Save grads and vars for later use in `build_apply_op`.
    policy._actor_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in actor_grads_and_vars
                                    if g is not None]
    policy._critic_grads_and_vars = [(clip_func(g), v)
                                     for (g, v) in critic_grads_and_vars
                                     if g is not None]

    grads_and_vars = (policy._actor_grads_and_vars +
                      policy._critic_grads_and_vars)
    return grads_and_vars
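
# The sketch below isolates the clipping pattern used above, assuming only
# TensorFlow (no RLlib): `functools.partial` freezes `clip_norm` so the same
# callable maps over every (grad, var) pair, `tf.identity` is the no-op when
# clipping is disabled, and pairs whose gradient is None (variables
# disconnected from the loss) are dropped. Note that `tf.clip_by_norm` clips
# each gradient tensor individually, unlike `tf.clip_by_global_norm`.
# `_demo_clip_grads_and_vars` is a hypothetical helper, not part of RLlib.
def _demo_clip_grads_and_vars(grads_and_vars, grad_clip=None):
    clip_func = (partial(tf.clip_by_norm, clip_norm=grad_clip)
                 if grad_clip else tf.identity)
    return [(clip_func(g), v) for (g, v) in grads_and_vars if g is not None]
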
def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                               loss: TensorType) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Note: For SAC, `optimizer` and `loss` are ignored because we have 3
    losses and 3 local optimizers (all stored on the policy).
    `optimizer` is still used, though, in the tf-eager case, because it is
    then a fake optimizer (OptimizerWrapper) object with a `tape` property
    to generate a GradientTape object for gradient recording.

    Args:
        policy (Policy): The Policy object that generated the loss tensor and
            that holds the given local optimizer.
        optimizer (LocalOptimizer): The tf (local) optimizer object to
            calculate the gradients with.
        loss (TensorType): The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the possibly clipped gradient- and variable
            tuples.
    """
    # Eager: Use GradientTape (which is a property of the `optimizer` object
    # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
    if policy.config["framework"] in ["tf2", "tfe"]:
        tape = optimizer.tape
        pol_weights = policy.model.policy_variables()
        actor_grads_and_vars = list(
            zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights))
        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            grads_1 = tape.gradient(policy.critic_loss[0],
                                    q_weights[:half_cutoff])
            grads_2 = tape.gradient(policy.critic_loss[1],
                                    q_weights[half_cutoff:])
            critic_grads_and_vars = \
                list(zip(grads_1, q_weights[:half_cutoff])) + \
                list(zip(grads_2, q_weights[half_cutoff:]))
        else:
            critic_grads_and_vars = list(
                zip(tape.gradient(policy.critic_loss[0], q_weights),
                    q_weights))
        alpha_vars = [policy.model.log_alpha]
        alpha_grads_and_vars = list(
            zip(tape.gradient(policy.alpha_loss, alpha_vars), alpha_vars))
    # Tf1.x: Use `optimizer.compute_gradients()`.
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables())
        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
            critic_grads_and_vars = base_q_optimizer.compute_gradients(
                policy.critic_loss[0], var_list=q_weights[:half_cutoff]
            ) + twin_q_optimizer.compute_gradients(
                policy.critic_loss[1], var_list=q_weights[half_cutoff:])
        else:
            critic_grads_and_vars = policy._critic_optimizer[
                0].compute_gradients(
                    policy.critic_loss[0], var_list=q_weights)
        alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
            policy.alpha_loss, var_list=[policy.model.log_alpha])

    # Clip if necessary.
    if policy.config["grad_clip"]:
        clip_func = partial(
            tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
    else:
        clip_func = tf.identity

    # Save grads and vars for later use in `build_apply_op`.
    policy._actor_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in actor_grads_and_vars
                                    if g is not None]
    policy._critic_grads_and_vars = [(clip_func(g), v)
                                     for (g, v) in critic_grads_and_vars
                                     if g is not None]
    policy._alpha_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in alpha_grads_and_vars
                                    if g is not None]

    grads_and_vars = (
        policy._actor_grads_and_vars + policy._critic_grads_and_vars +
        policy._alpha_grads_and_vars)
    return grads_and_vars
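
# Self-contained sketch (plain TF2, no RLlib) of the eager branch above: a
# single persistent GradientTape records all three SAC losses, and each loss
# is then differentiated only w.r.t. its own variable set (actor weights,
# Q weights, log-alpha). `tape.gradient()` may only be called more than once
# on a tape created with `persistent=True`, which is why the
# OptimizerWrapper's tape must be reusable across losses.
# `_demo_sac_tape_gradients` and its loss-fn arguments are hypothetical.
def _demo_sac_tape_gradients(actor_vars, q_vars, log_alpha, actor_loss_fn,
                             critic_loss_fn, alpha_loss_fn):
    with tf.GradientTape(persistent=True) as tape:
        # Losses must be computed inside the tape context to be recorded.
        actor_loss = actor_loss_fn()
        critic_loss = critic_loss_fn()
        alpha_loss = alpha_loss_fn()
    actor_grads = tape.gradient(actor_loss, actor_vars)
    critic_grads = tape.gradient(critic_loss, q_vars)
    alpha_grads = tape.gradient(alpha_loss, [log_alpha])
    del tape  # Release the persistent tape's resources.
    return actor_grads, critic_grads, alpha_grads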