Example #1
0
def _set_operation(a, b, set_operation, validate_indices=True):
    """Compute set operation of elements in last dimension of `a` and `b`.

    All but the last dimension of `a` and `b` must match.

    Args:
      a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
          must be sorted in row-major order.
      b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
          `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
          sorted in row-major order.
      set_operation: String indicating set operation. See
          SetOperationOp::SetOperationFromContext for valid values.
      validate_indices: Whether to validate the order and range of sparse
          indices in `a` and `b`.

    Returns:
      A `SparseTensor` with the same rank as `a` and `b`, and all but the last
      dimension the same. Elements along the last dimension contain the results
      of the set operation.

    Raises:
      TypeError: If the base dtype of `a` is not in `_VALID_DTYPES`, or if the
          base dtypes of `a` and `b` differ.
      ValueError: If `a` is sparse and `b` is dense.
    """
    a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
    if a.dtype.base_dtype not in _VALID_DTYPES:
        raise TypeError("'a' invalid dtype %s." % a.dtype)
    b = tensor_util.convert_to_tensor_or_sparse_tensor(b, name="b")
    if b.dtype.base_dtype != a.dtype.base_dtype:
        raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
    # pylint: disable=protected-access
    if isinstance(a, sparse_tensor.SparseTensor):
        if isinstance(b, sparse_tensor.SparseTensor):
            indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
                a.indices, a.values, a.shape, b.indices, b.values, b.shape,
                set_operation, validate_indices)
        else:
            # Only Sparse,Sparse, Dense,Sparse and Dense,Dense variants are
            # dispatched below; a sparse `a` with a dense `b` is rejected.
            raise ValueError(
                "Sparse,Dense is not supported, but Dense,Sparse is. "
                "Please flip the order of your inputs.")
    elif isinstance(b, sparse_tensor.SparseTensor):
        indices, values, shape = _set_ops.dense_to_sparse_set_operation(
            a, b.indices, b.values, b.shape, set_operation, validate_indices)
    else:
        indices, values, shape = _set_ops.dense_to_dense_set_operation(
            a, b, set_operation, validate_indices)
    # pylint: enable=protected-access
    return sparse_tensor.SparseTensor(indices, values, shape)
Example #2
0
def _set_operation(a, b, set_operation, validate_indices=True):
  """Apply a set operation along the last dimension of `a` and `b`.

  The leading dimensions (every dimension except the last) of `a` and `b`
  must agree.

  Args:
    a: `Tensor` or `SparseTensor` with the same dtype as `b`. When sparse, its
        indices must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` with the same dtype as `a`. Must be a
        `SparseTensor` whenever `a` is one. When sparse, its indices must be
        sorted in row-major order.
    set_operation: String naming the set operation; see
        SetOperationOp::SetOperationFromContext for the accepted values.
    validate_indices: Whether to validate the order and range of sparse
        indices in `a` and `b`.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, sharing all but their
    last dimension; the last dimension holds the result of the set operation.

  Raises:
    TypeError: If the base dtype of `a` is not in `_VALID_DTYPES`, or the base
        dtypes of `a` and `b` differ.
    ValueError: If `a` is sparse while `b` is dense.
  """
  a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = tensor_util.convert_to_tensor_or_sparse_tensor(b, name="b")
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))

  a_sparse = isinstance(a, ops.SparseTensor)
  b_sparse = isinstance(b, ops.SparseTensor)
  # Only three input layouts are dispatched; reject the unsupported one first.
  if a_sparse and not b_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  # pylint: disable=protected-access
  if a_sparse:
    result = _set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.shape, b.indices, b.values, b.shape,
        set_operation, validate_indices)
  elif b_sparse:
    result = _set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.shape, set_operation, validate_indices)
  else:
    result = _set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  # pylint: enable=protected-access
  out_indices, out_values, out_shape = result
  return ops.SparseTensor(out_indices, out_values, out_shape)
Example #3
0
def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a`.

  Returns:
    For `a` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st
    `n-1` dimensions as `a`. Each value is the number of unique elements in
    the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor` after conversion, or if the base
      dtype of its values is not in `_VALID_DTYPES`.
  """
  a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
  # Dense inputs are rejected: only the sparse kernel is called below.
  if not isinstance(a, ops.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  # pylint: disable=protected-access
  return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
Example #4
0
def set_size(a, validate_indices=True):
  """Count the unique elements along the last dimension of `a`.

  Args:
    a: `SparseTensor` whose indices are sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a`.

  Returns:
    `int32` `Tensor` of set sizes. If `a` has rank `n`, the result has rank
    `n-1` and shares the first `n-1` dimensions of `a`; each entry is the
    number of unique elements in the corresponding `[0...n-1]` slice of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor` after conversion, or the base
      dtype of its values is not in `_VALID_DTYPES`.
  """
  sp = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
  if not isinstance(sp, ops.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % sp)
  values_dtype = sp.values.dtype
  if values_dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % values_dtype)
  # pylint: disable=protected-access
  return _set_ops.set_size(
      sp.indices, sp.values, sp.shape, validate_indices)