コード例 #1
0
def _SparseAddGrad(op, *grads):
    """Compute the gradients of a SparseAdd op with respect to its inputs.

    SparseAdd produces A + B, where A, B and the result are all represented
    as `SparseTensor` objects.  Only the values of A and B receive a
    gradient; it is derived from the upstream gradient w.r.t. the non-empty
    values of the sum.

    Args:
      op: the SparseAdd op being differentiated.
      *grads: the incoming gradients, one element per output of `op`.  Only
        `grads[1]` is read here.

    Returns:
      A 7-tuple holding one gradient per input of SparseAdd, ordered as
        (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh).
      Every entry is None except the ones for `a_values` and `b_values`.
    """
    upstream_values_grad = grads[1]
    indices_of_a = op.inputs[0]
    indices_of_b = op.inputs[3]
    indices_of_sum = op.outputs[0]

    # NOTE: `thresh` is deliberately ignored.  It only decides which elements
    # of the sum are non-zero, and the gradient op inspects `indices_of_sum`
    # directly.
    # pylint: disable=protected-access
    value_grads = gen_sparse_ops._sparse_add_grad(
        upstream_values_grad, indices_of_a, indices_of_b, indices_of_sum)
    grad_a_values, grad_b_values = value_grads

    # Give each gradient the shape reported by the values input it belongs to
    # (input 1 holds the values of A, input 4 the values of B).
    for values_grad, values_slot in ((grad_a_values, 1), (grad_b_values, 4)):
        values_grad.set_shape(op.inputs[values_slot].get_shape())

    # (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
    return (None, grad_a_values, None, None, grad_b_values, None, None)
コード例 #2
0
ファイル: sparse_grad.py プロジェクト: 1000sprites/tensorflow
def _SparseAddGrad(op, *grads):
  """Backward pass of the SparseAdd op.

  SparseAdd evaluates A + B with A, B and the sum all held as `SparseTensor`
  objects.  Given the upstream gradient w.r.t. the non-empty values of the
  sum, this returns the gradients w.r.t. the non-empty values of A and of B.

  Args:
    op: the SparseAdd op.
    *grads: the incoming gradients, one element per output of `op`; the
      element at index 1 is the one used.

  Returns:
    A tuple with one gradient for each of the 7 inputs of SparseAdd:
      (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
    Only the `a_values` and `b_values` positions are populated; the indices,
    shapes and threshold get None.
  """
  sum_values_grad = grads[1]
  lhs_indices, rhs_indices = op.inputs[0], op.inputs[3]
  result_indices = op.outputs[0]

  # NOTE: `thresh` plays no part here: it merely determines the non-zero
  # elements of the sum, and the gradient op looks at `result_indices` itself.

  # pylint: disable=protected-access
  lhs_grad, rhs_grad = gen_sparse_ops._sparse_add_grad(
      sum_values_grad, lhs_indices, rhs_indices, result_indices)

  # The value gradients take the shapes of the corresponding value inputs.
  lhs_grad.set_shape(op.inputs[1].get_shape())
  rhs_grad.set_shape(op.inputs[4].get_shape())

  # Slots follow the op's input order:
  # (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
  input_grads = [None] * 7
  input_grads[1] = lhs_grad
  input_grads[4] = rhs_grad
  return tuple(input_grads)