import numpy as np


# Note: `apply_grad_clipping` is RLlib's built-in torch gradient-clipping helper;
# `apply_grad_clipping_elementwise` is assumed to be a custom, project-specific
# helper with the same (policy, optimizer, loss) signature.
def my_apply_grad_clipping(policy, optimizer, loss):
    # Apply elementwise gradient clipping first, so that the larger gradients at the
    # end of the network do not dominate the subsequent clipping by global norm.
    info = apply_grad_clipping_elementwise(policy, optimizer, loss)

    # Update the grad clip value depending on the mode.
    if policy.config["grad_clip_options"]["mode"] == "adaptive":
        assert policy.config.get("grad_clip_elementwise", None) is not None
        if (len(policy.prev_gradient_norms) ==
                policy.config["grad_clip_options"]["adaptive_buffer_size"]):
            # Compute the grad clip value as a percentile of the previous buffer_size gradient norms.
            grad_clip = np.percentile(policy.prev_gradient_norms,
                                      q=policy.config["grad_clip_options"]["adaptive_percentile"])
            # Clip the grad clip value to a reasonable range.
            grad_clip = np.clip(
                grad_clip,
                policy.config["grad_clip_options"]["adaptive_min"],
                policy.config["grad_clip_options"]["adaptive_max"],
            )
            # Update the grad clip value on the policy. This will take effect below.
            policy.config["grad_clip"] = grad_clip
            # Track the effective grad clip value as a metric.
            info["effective_grad_clip"] = grad_clip

        # Update the buffer of recent gradient norms.
        current_gradient_norm = info["after_ele_clip_global_grad_norm"]
        policy.prev_gradient_norms.append(current_gradient_norm)

    # Apply gradient clipping per usual, possibly using the updated grad clip value.
    global_info = apply_grad_clipping(policy, optimizer, loss)
    if "grad_gnorm" in global_info:
        info["final_grad_global_norm"] = global_info["grad_gnorm"].to("cpu")

    return info
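
For reference, here is a minimal sketch of the configuration and policy state that my_apply_grad_clipping expects. The key names are taken from the code above; the concrete values and the deque-based buffer are assumptions for illustration only.

from collections import deque

# Config keys read by my_apply_grad_clipping (values here are illustrative).
example_grad_clip_config = {
    "grad_clip": 40.0,               # starting global-norm clip value
    "grad_clip_elementwise": 10.0,   # per-element clip used by the elementwise helper
    "grad_clip_options": {
        "mode": "adaptive",          # "adaptive" enables the percentile-based update
        "adaptive_buffer_size": 100, # number of recent gradient norms to keep
        "adaptive_percentile": 95,   # percentile of the buffer used as the new clip
        "adaptive_min": 1.0,         # lower bound on the adaptive clip value
        "adaptive_max": 100.0,       # upper bound on the adaptive clip value
    },
}

# The policy is also expected to carry a bounded buffer of recent gradient norms,
# e.g. created in the policy's __init__ (assumption):
#     policy.prev_gradient_norms = deque(
#         maxlen=example_grad_clip_config["grad_clip_options"]["adaptive_buffer_size"])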
Example #2
def grad_process_and_td_error_fn(policy, optimizer, loss):
    # Clip grads if configured.
    info = apply_grad_clipping(policy, optimizer, loss)
    # Add td-error to info dict.
    info["td_error"] = policy.q_loss.td_error
    return info
Example #3
def grad_process_and_td_error_fn(policy: Policy,
                                 optimizer: "torch.optim.Optimizer",
                                 loss: TensorType) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)
Example #4
def grad_process_and_td_error_fn(policy, optimizer, loss):
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)
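
All of the hooks above share the (policy, optimizer, loss) signature that RLlib expects for extra_grad_process_fn. The sketch below shows how such a hook is typically registered when building a torch policy; MyTorchPolicy, dummy_loss, and the Ray 1.x-era import path are assumptions for illustration, not part of the examples above.

from ray.rllib.policy.torch_policy_template import build_torch_policy


def dummy_loss(policy, model, dist_class, train_batch):
    # Hypothetical placeholder loss: forward pass, then a trivial objective
    # on the model outputs (illustration only).
    model_out, _ = model(train_batch)
    return (model_out ** 2).mean()


MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=dummy_loss,
    # Run gradient clipping (and optionally td-error reporting) after each
    # backward pass, before the optimizer step.
    extra_grad_process_fn=grad_process_and_td_error_fn,
)

A custom hook such as my_apply_grad_clipping from the first example could be registered the same way, as long as it returns a dict of stats to report.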