from trax import fastmath
from trax.fastmath import numpy as jnp


def _average_multidevice_gradients(gradients, adasum=False):
  """Averages gradients over all the devices across different hosts."""
  gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
  n = fastmath.psum(jnp.array(1.0), 'batch')  # number of devices on all hosts
  if not adasum:
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
  # This implements an approximation of the Adasum algorithm from this paper:
  # https://arxiv.org/pdf/2006.02924.pdf
  # Implementing the recursive halve-and-merge reduction is tricky, so instead
  # we compare each device's gradient against the sum over all devices. For
  # 2 devices this matches the paper; with more devices the averaging differs,
  # but it keeps the key property that orthogonal gradients are summed while
  # identical ones are averaged, as postulated in the paper.
  adasum_nominator = fastmath.nested_map_multiarg(
      lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
      gradients, gradients_psum)
  grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
  # If all devices have identical gradients, the numerator equals
  # n * grad_norm; if they are mutually orthogonal, it equals grad_norm.
  scaled_grads = fastmath.nested_map_multiarg(
      lambda g, nominator, g_norm: g * (
          1 - (nominator - g_norm) / (n * g_norm)),
      gradients, adasum_nominator, grad_norm)
  return fastmath.psum(scaled_grads, 'batch')
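# A minimal single-host sketch of the scaling rule above (illustrative only,
# not part of the library; `_adasum_average_demo` is a hypothetical helper).
# The leading axis of `grads` stands in for the 'batch' device axis, so each
# psum becomes a plain sum. It demonstrates the key property: identical
# per-device gradients are averaged, mutually orthogonal ones are summed.
import numpy as np


def _adasum_average_demo(grads):
  """grads: array of shape [n_devices, dim]; returns the combined gradient."""
  n = grads.shape[0]
  g_psum = grads.sum(axis=0)                        # psum over all devices
  nominator = np.einsum('nd,d->n', grads, g_psum)   # per-device vdot(g, g_psum)
  g_norm = np.einsum('nd,nd->n', grads, grads)      # per-device vdot(g, g)
  scale = 1 - (nominator - g_norm) / (n * g_norm)
  return (grads * scale[:, None]).sum(axis=0)       # psum of scaled gradients


g = np.array([1.0, 0.0])
print(_adasum_average_demo(np.stack([g, g])))                # [1. 0.] (average)
print(_adasum_average_demo(np.array([[1., 0.], [0., 1.]])))  # [1. 1.] (sum)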
def l2_norm(tree):
  """Returns an L2 norm computed over all elements of all tensors in `tree`.

  This is useful, e.g., for weight decay.

  Args:
    tree: Tree-structured collection of tensors, e.g., model weights matching
        the model's layer structure.

  Returns:
    A scalar value computed as if all the tensors in `tree` were combined
    and flattened into a single vector, and then the L2 norm of that vector
    was calculated.
  """
  leaves = fastmath.tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
def _l2_norm(self, flat_list):
  """Returns an L2-like norm of all elements of all tensors in `flat_list`.

  Args:
    flat_list: Collection of tensors as a flat list (rather than, e.g., a
        tree).

  Returns:
    A scalar value computed as if all the tensors in `flat_list` were joined
    and flattened into a single vector, and then the L2 norm of that vector
    was calculated.
  """
  if fastmath.is_backend(fastmath.Backend.JAX):
    norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
  else:  # TODO(lukaszkaiser): add vdot to TF-numpy
    norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list))
  return norm
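# Example usage of `l2_norm` (a sketch following Trax's gradient-clipping
# pattern, not necessarily verbatim library code): when the global norm
# exceeds `max_norm`, every tensor in the tree is scaled by the same factor,
# preserving the direction of the combined gradient.
def clip_grads(grad_tree, max_norm):
  """Proportionally scales gradients so their global L2 norm is <= max_norm."""
  norm = l2_norm(grad_tree)
  normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
  return fastmath.nested_map(normalize, grad_tree)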
def _adasum_merge(g1, g2):
  """Adasum gradient composition, see https://arxiv.org/pdf/2006.02924.pdf."""
  frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30)
  frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30)
  return (1 - frac1) * g1 + (1 - frac2) * g2
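# Quick numerical check of `_adasum_merge` (illustrative only): identical
# gradients are averaged, orthogonal gradients are summed, matching the
# pairwise rule from the Adasum paper.
import numpy as np

g = np.array([1.0, 0.0])
h = np.array([0.0, 1.0])
print(_adasum_merge(g, g))  # frac1 = frac2 = 1/2  ->  g (the average)
print(_adasum_merge(g, h))  # frac1 = frac2 = 0    ->  g + h (the sum)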