def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, check_inf_nan):
    """Combine one shared variable's per-tower gradients into a single gradient.

    The combined gradient is built from every tower's gradient, so this
    function acts as a synchronization point across all towers.

    Args:
      grad_and_vars: A list or tuple of (gradient, variable) pairs, one pair
        per tower, all referring to the same shared variable.
      use_mean: If True (and there is more than one tower), the summed
        gradient is scaled by 1 / number_of_towers; otherwise the plain sum
        is returned.
      check_inf_nan: If True, also build a flag reporting whether the
        per-tower gradients contain any non-finite value.

    Returns:
      A pair `((gradient, variable), has_nan_or_inf)`. `gradient` is the sum
      or mean of the per-tower gradients, `variable` is taken from the first
      tower's pair, and `has_nan_or_inf` is the result of the finite check
      when `check_inf_nan` is True, or None when it is False.
    """
    tower_grads = [tower_grad for tower_grad, _ in grad_and_vars]
    num_towers = len(tower_grads)

    # Sum first; the mean is derived from the sum only when it would differ.
    combined = math_ops.add_n(tower_grads)
    if use_mean and num_towers > 1:
        combined = array_ops.multiply(combined, 1.0 / num_towers)

    # The variable is shared, so the first tower's copy represents all of them.
    shared_var = grad_and_vars[0][1]
    result = (combined, shared_var)

    if not check_inf_nan:
        return result, None

    # The check runs on the raw per-tower gradients, not the combined one.
    all_finite = array_ops.reduce_all(array_ops.is_finite(tower_grads))
    return result, array_ops.logical_not(all_finite)
def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, check_inf_nan):
    """Combine one shared variable's per-replica gradients into one gradient.

    The combined gradient is built from every replica's gradient, so this
    function acts as a synchronization point across all replicas.

    Args:
      grad_and_vars: A list or tuple of (gradient, variable) pairs, one pair
        per replica, all referring to the same shared variable.
      use_mean: If True (and there is more than one replica), the summed
        gradient is scaled by 1 / number_of_replicas; otherwise the plain sum
        is returned.
      check_inf_nan: If True, also build a flag reporting whether the
        per-replica gradients contain any non-finite value.

    Returns:
      A pair `((gradient, variable), has_nan_or_inf)`. `gradient` is the sum
      or mean of the per-replica gradients, `variable` is taken from the
      first replica's pair, and `has_nan_or_inf` is the result of the finite
      check when `check_inf_nan` is True, or None when it is False.
    """
    replica_grads = [replica_grad for replica_grad, _ in grad_and_vars]
    num_replicas = len(replica_grads)

    # Sum first; the mean is derived from the sum only when it would differ.
    combined = math_ops.add_n(replica_grads)
    if use_mean and num_replicas > 1:
        combined = array_ops.multiply(combined, 1.0 / num_replicas)

    # The variable is shared, so the first replica's copy stands for all.
    shared_var = grad_and_vars[0][1]
    result = (combined, shared_var)

    if not check_inf_nan:
        return result, None

    # The check runs on the raw per-replica gradients, not the combined one.
    all_finite = array_ops.reduce_all(array_ops.is_finite(replica_grads))
    return result, array_ops.logical_not(all_finite)