示例#1
0
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  The gradient of a sum all-reduce is itself a sum all-reduce over the
  incoming gradients, so we emit a second `nccl_all_reduce` op.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the input of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except sum.')

  # The gradient op must be placed on the same device as the forward op.
  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  # Distinct shared_name so the backward all-reduce forms its own NCCL
  # communicator group, separate from the forward pass.
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)
示例#2
0
def _all_sum_grad(op, grad):
    """The gradients for `all_sum`.

    The gradient of a sum all-reduce is itself a sum all-reduce over the
    incoming gradients, so we emit a second `nccl_all_reduce` op.

    Args:
        op: The `all_sum` `Operation` that we are differentiating.
        grad: Gradient with respect to the output of the `all_sum` op.

    Returns:
        The gradient with respect to the input of `all_sum`.

    Raises:
        LookupError: If `reduction` is not `sum`.
    """
    if op.get_attr('reduction') != b'sum':
        raise LookupError('No gradient defined for NcclAllReduce except for '
                          'reduction="sum".')

    # The gradient op must be placed on the same device as the forward op.
    _check_device(grad, expected=op.device)
    num_devices = op.get_attr('num_devices')
    # Distinct shared_name so the backward all-reduce forms its own NCCL
    # communicator group, separate from the forward pass.
    shared_name = op.get_attr('shared_name') + b'_grad'

    with ops.device(op.device):
        return gen_nccl_ops.nccl_all_reduce(input=grad,
                                            reduction='sum',
                                            num_devices=num_devices,
                                            shared_name=shared_name)
示例#3
0
 def _all_reduce():
     """Build one NCCL all-reduce op per tensor and return them all."""
     outputs = []
     for tensor in tensors:
         _check_device(tensor)
         # Each all-reduce op is placed on its own tensor's device.
         with ops.device(tensor.device):
             reduced = gen_nccl_ops.nccl_all_reduce(
                 input=tensor,
                 reduction=reduction,
                 num_devices=len(tensors),
                 shared_name=shared_name)
         outputs.append(reduced)
     return outputs
示例#4
0
 def _all_reduce():
   """Build one NCCL all-reduce op per tensor and return them all."""
   reduced = []
   # Loop-invariant: every op in the group reduces across all tensors.
   num = len(tensors)
   for tensor in tensors:
     _check_device(tensor)
     with ops.device(tensor.device):
       reduced.append(gen_nccl_ops.nccl_all_reduce(input=tensor,
                                                   reduction=reduction,
                                                   num_devices=num,
                                                   shared_name=shared_name))
   return reduced
示例#5
0
def _apply_all_reduce(reduction, tensors):
    """Helper function for all_* functions."""
    if not tensors:
        raise ValueError('Must pass >0 tensors to all reduce operations')

    # One shared_name ties the per-tensor ops into a single NCCL group.
    shared_name = _get_shared_name()
    num_devices = len(tensors)

    outputs = []
    for tensor in tensors:
        _check_device(tensor)
        # Place each all-reduce op on its own tensor's device.
        with ops.device(tensor.device):
            reduced = gen_nccl_ops.nccl_all_reduce(
                input=tensor,
                reduction=reduction,
                num_devices=num_devices,
                shared_name=shared_name)
        outputs.append(reduced)

    return outputs
示例#6
0
def _apply_all_reduce(reduction, tensors):
  """Shared implementation behind the all_* convenience wrappers."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  # One shared_name ties the per-tensor ops into a single NCCL group.
  shared_name = _get_shared_name()

  reduced_tensors = []
  for tensor in tensors:
    _check_device(tensor)
    # Place each all-reduce op on its own tensor's device.
    with ops.device(tensor.device):
      reduced_tensors.append(
          gen_nccl_ops.nccl_all_reduce(input=tensor,
                                       reduction=reduction,
                                       num_devices=len(tensors),
                                       shared_name=shared_name))

  return reduced_tensors