from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops


@ops.RegisterGradient('NcclAllReduce')
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of the `all_sum` op.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except for '
                      'reduction="sum".')

  _check_device(grad, expected=op.device)

  # The gradient of an all-sum is itself an all-sum over the same devices,
  # under a distinct shared_name so it does not collide with the forward op.
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)
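# `_check_device` is referenced above but not defined in this section. A
# minimal sketch of its expected contract follows; this is an illustration
# under assumed semantics, not the module's actual implementation.
def _check_device(tensor, expected=None):
  """Raise if `tensor` lacks an explicit device (or differs from `expected`)."""
  if not tensor.device:
    raise ValueError(
        'Device assignment required for nccl collective ops: %s' % tensor)
  if expected and expected != tensor.device:
    raise ValueError(
        'Expected device %s, got %s' % (expected, tensor.device))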
def _all_reduce():
  """Call nccl allreduce."""
  # Note: this is a thunk. `tensors`, `reduction`, and `shared_name` are free
  # variables closed over from an enclosing scope (presumably a variant of
  # `_apply_all_reduce` below).
  res = []
  for t in tensors:
    _check_device(t)
    with ops.device(t.device):
      res.append(
          gen_nccl_ops.nccl_all_reduce(
              input=t,
              reduction=reduction,
              num_devices=len(tensors),
              shared_name=shared_name))
  return res
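# A hedged sketch of how the `_all_reduce` thunk above could be driven. NCCL
# kernels block until every participating device has launched its op, so under
# eager execution the launches must happen concurrently inside a single graph;
# wrapping the thunk in a `tf.function` achieves that. The dispatcher name and
# the surrounding wiring here are illustrative assumptions.
def _dispatch_all_reduce():
  from tensorflow.python.eager import context
  from tensorflow.python.eager import def_function
  if context.executing_eagerly():
    # Trace the thunk so all nccl_all_reduce ops launch within one graph.
    return def_function.function(_all_reduce)()
  return _all_reduce()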
def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()
  res = []
  for t in tensors:
    _check_device(t)
    with ops.device(t.device):
      res.append(
          gen_nccl_ops.nccl_all_reduce(
              input=t,
              reduction=reduction,
              num_devices=len(tensors),
              shared_name=shared_name))
  return res
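# `_get_shared_name` is called above but not defined in this section. NCCL
# groups the ops of one collective by `shared_name`, so each call needs a
# process-unique name. A minimal thread-safe sketch, assuming a simple global
# counter:
import threading

_module_lock = threading.Lock()
_shared_name_counter = 0


def _get_shared_name():
  """Return a new unique shared_name for one all-reduce group."""
  global _shared_name_counter
  with _module_lock:
    val = _shared_name_counter
    _shared_name_counter += 1
  return 'c%s' % val


# Usage sketch: a public `all_sum`-style wrapper (name assumed from the
# docstrings above) would simply fix the reduction and delegate:
def all_sum(tensors):
  """Returns a list of tensors, each the sum of `tensors` across devices."""
  return _apply_all_reduce('sum', tensors)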