Beispiel #1
0
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

        The result is close to an exact solution when `A` is well
        conditioned; otherwise its accuracy varies.  See the class docstring
        for details.

        Examples:

        ```python
        # Operator acting like batch matrix A, with A.shape = [..., M, N].
        operator = LinearOperator(...)

        # Solve R > 0 linear systems for every member of the batch.
        RHS = ...  # shape [..., M, R]

        X = operator.solve(RHS)
        # X[..., :, r] solves the r'th system:
        # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

        operator.matmul(X)
        ==> RHS
        ```

        Args:
          rhs: `Tensor` with the same `dtype` as this operator and a
            compatible shape.  It is treated as a [batch] matrix: for every
            set of leading dimensions, the last two dimensions define a
            matrix.  See the class docstring for the definition of
            compatibility.
          adjoint: Python `bool`.  If `True`, solve the system involving the
            adjoint of this `LinearOperator`: `A^H X = rhs`.
          adjoint_arg: Python `bool`.  If `True`, solve `A X = rhs^H`, where
            `rhs^H` is the hermitian transpose (transposition and complex
            conjugation).
          name: A name scope to use for ops added by this method.

        Returns:
          `Tensor` with shape `[..., N, R]` and the same `dtype` as `rhs`.

        Raises:
          NotImplementedError: If `self.is_non_singular` or `is_square` is
            False.
        """
        # Only an explicit `False` triggers these errors; any other value
        # (e.g. `None`) lets the solve proceed.
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")

        with self._name_scope(name, values=[rhs]):
            rhs = ops.convert_to_tensor(rhs, name="rhs")
            self._check_input_dtype(rhs)

            # The equation axis of `rhs` (its last axis when `adjoint_arg`)
            # must match the rows of `A` (its columns when `adjoint`).
            op_axis = -1 if adjoint else -2
            rhs_axis = -1 if adjoint_arg else -2
            op_dim = tensor_shape.dimension_at_index(self.shape, op_axis)
            op_dim.assert_is_compatible_with(rhs.get_shape()[rhs_axis])

            return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        `x` is treated as a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` (`[..., N]` when `adjoint` is True) and
      same `dtype` as `self`.
    """
    with self._name_scope(name):
      x = ops.convert_to_tensor(x, name="x")
      self._check_input_dtype(x)
      # The length of `x` must match the columns of `A`, or the rows of `A`
      # when multiplying by the adjoint.
      self_dim = -2 if adjoint else -1
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(x.get_shape()[-1])
      return self._matvec(x, adjoint=adjoint)
  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The result is close to an exact solution when `A` is well conditioned;
    otherwise its accuracy varies.  See the class docstring for details.

    Examples:

    ```python
    # Operator acting like batch matrix A, with A.shape = [..., M, N].
    operator = LinearOperator(...)

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ...  # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] solves the r'th system:
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with the same `dtype` as this operator and a compatible
        shape.  It is treated as a [batch] matrix: for every set of leading
        dimensions, the last two dimensions define a matrix.  See the class
        docstring for the definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the
        adjoint of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`.  If `True`, solve `A X = rhs^H`, where
        `rhs^H` is the hermitian transpose (transposition and complex
        conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and the same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    # Only an explicit `False` triggers these errors; any other value
    # (e.g. `None`) lets the solve proceed.
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")

    with self._name_scope(name, values=[rhs]):
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      self._check_input_dtype(rhs)

      # The equation axis of `rhs` (its last axis when `adjoint_arg`) must
      # match the rows of `A` (its columns when `adjoint`).
      op_axis = -1 if adjoint else -2
      rhs_axis = -1 if adjoint_arg else -2
      op_dim = tensor_shape.dimension_at_index(self.shape, op_axis)
      op_dim.assert_is_compatible_with(rhs.get_shape()[rhs_axis])

      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
Beispiel #4
0
    def static_nrows(self):
        """The number of rows in this partition, if statically known.

        ```python
        self.row_lengths().shape == [self.static_nrows]
        self.row_starts().shape == [self.static_nrows]
        self.row_limits().shape == [self.static_nrows]
        self.row_splits().shape == [self.static_nrows + 1]
        ```

        Returns:
          A `tensor_shape.Dimension` with a known value holding the number of
          rows in this partition (if statically known); or `None` (otherwise).
        """
        # Try, in order: the static length of row_splits (which has one more
        # entry than there are rows), the static length of row_lengths, and
        # finally the constant value of `_nrows`.
        if self._row_splits is not None:
            nrows = tensor_shape.dimension_at_index(self._row_splits.shape,
                                                    0) - 1
            if nrows.value is not None:
                return nrows
        if self._row_lengths is not None:
            nrows = tensor_shape.dimension_at_index(self._row_lengths.shape, 0)
            if nrows.value is not None:
                return nrows
        if self._nrows is not None:
            nrows_value = tensor_util.constant_value(self._nrows)
            # `constant_value` yields None when `_nrows` cannot be evaluated
            # statically; treat that as "unknown", like the branches above,
            # instead of returning `Dimension(None)`.
            if nrows_value is not None:
                return tensor_shape.Dimension(nrows_value)
        return None
Beispiel #5
0
    def matvec(self, x, adjoint=False, name="matvec"):
        """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

        ```python
        # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
        operator = LinearOperator(...)

        X = ... # shape [..., N], batch vector

        Y = operator.matvec(X)
        Y.shape
        ==> [..., M]

        Y[..., :] = sum_j A[..., :, j] X[..., j]
        ```

        Args:
          x: `Tensor` with compatible shape and same `dtype` as `self`.
            `x` is treated as a [batch] vector meaning for every set of leading
            dimensions, the last dimension defines a vector.
            See class docstring for definition of compatibility.
          adjoint: Python `bool`.  If `True`, left multiply by the adjoint:
            `A^H x`.
          name:  A name for this `Op`.

        Returns:
          A `Tensor` with shape `[..., M]` (`[..., N]` when `adjoint` is True)
          and same `dtype` as `self`.
        """
        with self._name_scope(name):
            x = ops.convert_to_tensor(x, name="x")
            self._check_input_dtype(x)
            # The length of `x` must match the columns of `A`, or the rows of
            # `A` when multiplying by the adjoint.
            self_dim = -2 if adjoint else -1
            tensor_shape.dimension_at_index(
                self.shape,
                self_dim).assert_is_compatible_with(x.get_shape()[-1])
            return self._matvec(x, adjoint=adjoint)
Beispiel #6
0
def _verify_input(tensor_list, labels, probs_list):
  """Verify that batched inputs are well-formed."""
  checked_probs_list = []
  for probs in probs_list:
    # Since number of classes shouldn't change at runtime, probabilities shape
    # should be fully defined.
    probs.get_shape().assert_is_fully_defined()

    # Probabilities must be 1D.
    probs.get_shape().assert_has_rank(1)

    # Probabilities must be nonnegative and sum to one.
    tol = 1e-6
    prob_sum = math_ops.reduce_sum(probs)
    checked_probs = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(prob_sum, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, prob_sum)
    ], probs)
    checked_probs_list.append(checked_probs)

  # All probabilities should be the same length.
  prob_length = checked_probs_list[0].get_shape().num_elements()
  for checked_prob in checked_probs_list:
    if checked_prob.get_shape().num_elements() != prob_length:
      raise ValueError('Probability parameters must have the same length.')

  # Labels tensor should only have batch dimension.
  labels.get_shape().assert_has_rank(1)

  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        labels.get_shape()[0])

  # Data and labels must have the same, strictly positive batch size. Since we
  # can't assume we know the batch size at graph creation, add runtime checks.
  labels_batch_size = array_ops.shape(labels)[0]
  lbl_assert = check_ops.assert_positive(labels_batch_size)

  # Make each tensor depend on its own checks.
  labels = control_flow_ops.with_dependencies([lbl_assert], labels)
  tensor_list = [
      control_flow_ops.with_dependencies([
          lbl_assert,
          check_ops.assert_equal(array_ops.shape(x)[0], labels_batch_size)
      ], x) for x in tensor_list
  ]

  # Label's classes must be integers 0 <= x < num_classes.
  labels = control_flow_ops.with_dependencies([
      check_ops.assert_integer(labels), check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(prob_length, labels.dtype))
  ], labels)

  return tensor_list, labels, checked_probs_list
Beispiel #7
0
def _verify_input(tensor_list, labels, probs_list):
  """Verify that batched inputs are well-formed."""
  checked_probs_list = []
  for probs in probs_list:
    # Since number of classes shouldn't change at runtime, probabilities shape
    # should be fully defined.
    probs.get_shape().assert_is_fully_defined()

    # Probabilities must be 1D.
    probs.get_shape().assert_has_rank(1)

    # Probabilities must be nonnegative and sum to one.
    tol = 1e-6
    prob_sum = math_ops.reduce_sum(probs)
    checked_probs = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(prob_sum, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, prob_sum)
    ], probs)
    checked_probs_list.append(checked_probs)

  # All probabilities should be the same length.
  prob_length = checked_probs_list[0].get_shape().num_elements()
  for checked_prob in checked_probs_list:
    if checked_prob.get_shape().num_elements() != prob_length:
      raise ValueError('Probability parameters must have the same length.')

  # Labels tensor should only have batch dimension.
  labels.get_shape().assert_has_rank(1)

  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        labels.get_shape()[0])

  # Data and labels must have the same, strictly positive batch size. Since we
  # can't assume we know the batch size at graph creation, add runtime checks.
  labels_batch_size = array_ops.shape(labels)[0]
  lbl_assert = check_ops.assert_positive(labels_batch_size)

  # Make each tensor depend on its own checks.
  labels = control_flow_ops.with_dependencies([lbl_assert], labels)
  tensor_list = [
      control_flow_ops.with_dependencies([
          lbl_assert,
          check_ops.assert_equal(array_ops.shape(x)[0], labels_batch_size)
      ], x) for x in tensor_list
  ]

  # Label's classes must be integers 0 <= x < num_classes.
  labels = control_flow_ops.with_dependencies([
      check_ops.assert_integer(labels), check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(prob_length, labels.dtype))
  ], labels)

  return tensor_list, labels, checked_probs_list
Beispiel #8
0
    def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
        """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

        ```python
        # Make an operator acting like batch matrix A.
        # Assume A.shape = [..., M, N]
        operator = LinearOperator(...)
        operator.shape = [..., M, N]

        X = ... # shape [..., N, R], batch matrix, R > 0.

        Y = operator.matmul(X)
        Y.shape
        ==> [..., M, R]

        Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
        ```

        Args:
          x: `LinearOperator` or `Tensor` with compatible shape and same
            `dtype` as `self`. See class docstring for definition of
            compatibility.
          adjoint: Python `bool`.  If `True`, left multiply by the adjoint:
            `A^H x`.
          adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H`
            is the hermitian transpose (transposition and complex conjugation).
          name:  A name for this `Op`.

        Returns:
          A `LinearOperator` (when `x` is a `LinearOperator`) or a `Tensor`
          with shape `[..., M, R]` and same `dtype` as `self`.

        Raises:
          ValueError: If `x` is a `LinearOperator` whose range dimension
            differs from this operator's domain dimension, or if the static
            shapes of `self` and a `Tensor` `x` are incompatible.
        """
        if isinstance(x, LinearOperator):
            # Operator-by-operator products are delegated to
            # `linear_operator_algebra.matmul` after checking that the inner
            # dimensions agree.
            left_operator = self.adjoint() if adjoint else self
            right_operator = x.adjoint() if adjoint_arg else x

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `x` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.matmul(left_operator,
                                                      right_operator)

        with self._name_scope(name):
            x = ops.convert_to_tensor(x, name="x")
            self._check_input_dtype(x)

            # The contracted axis of `x` (its rows, or its columns when
            # `adjoint_arg`) must match the columns of `A` (its rows when
            # `adjoint`).
            self_dim = -2 if adjoint else -1
            arg_dim = -1 if adjoint_arg else -2
            tensor_shape.dimension_at_index(
                self.shape,
                self_dim).assert_is_compatible_with(x.get_shape()[arg_dim])

            return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
Beispiel #9
0
def _verify_data_inputs(tensor_list):
    """Verify that batched data inputs are well-formed."""
    for tensor in tensor_list:
        # Data tensor should have a batch dimension.
        shape = tensor.get_shape().with_rank_at_least(1)

        # Data batch dimensions must be compatible.
        tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
            tensor_list[0].get_shape()[0])

    return tensor_list
  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
        `self`. See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` (when `x` is a `LinearOperator`) or a `Tensor` with
      shape `[..., M, R]` and same `dtype` as `self`.

    Raises:
      ValueError: If `x` is a `LinearOperator` whose range dimension differs
        from this operator's domain dimension, or if the static shapes of
        `self` and a `Tensor` `x` are incompatible.
    """
    if isinstance(x, LinearOperator):
      # Operator-by-operator products are delegated to
      # `linear_operator_algebra.matmul` after checking that the inner
      # dimensions agree.
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):
      x = ops.convert_to_tensor(x, name="x")
      self._check_input_dtype(x)

      # The contracted axis of `x` (its rows, or its columns when
      # `adjoint_arg`) must match the columns of `A` (its rows when `adjoint`).
      self_dim = -2 if adjoint else -1
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              x.get_shape()[arg_dim])

      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
Beispiel #11
0
def _verify_data_inputs(tensor_list):
  """Verify that batched data inputs are well-formed."""
  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        tensor_list[0].get_shape()[0])

  return tensor_list
  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator.
        `rhs` is treated like a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.  See class docstring
        for definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    # Enforce the documented contract here, exactly as `solve` does, rather
    # than depending on whatever `_solvevec` happens to do.
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      self._check_input_dtype(rhs)
      # The length of `rhs` must match the rows of `A` (its columns when
      # solving with the adjoint).
      self_dim = -1 if adjoint else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              rhs.get_shape()[-1])

      return self._solvevec(rhs, adjoint=adjoint)
Beispiel #13
0
    def solvevec(self, rhs, adjoint=False, name="solve"):
        """Solve single equation with best effort: `A X = rhs`.

        The returned `Tensor` will be close to an exact solution if `A` is
        well conditioned. Otherwise closeness will vary. See class docstring
        for details.

        Examples:

        ```python
        # Make an operator acting like batch matrix A.
        # Assume A.shape = [..., M, N]
        operator = LinearOperator(...)
        operator.shape = [..., M, N]

        # Solve one linear system for every member of the batch.
        RHS = ... # shape [..., M]

        X = operator.solvevec(RHS)
        # X is the solution to the linear system
        # sum_j A[..., :, j] X[..., j] = RHS[..., :]

        operator.matvec(X)
        ==> RHS
        ```

        Args:
          rhs: `Tensor` with same `dtype` as this operator.
            `rhs` is treated like a [batch] vector meaning for every set of
            leading dimensions, the last dimension defines a vector.  See class
            docstring for definition of compatibility regarding batch
            dimensions.
          adjoint: Python `bool`.  If `True`, solve the system involving the
            adjoint of this `LinearOperator`:  `A^H X = rhs`.
          name:  A name scope to use for ops added by this method.

        Returns:
          `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

        Raises:
          NotImplementedError:  If `self.is_non_singular` or `is_square` is
            False.
        """
        # Enforce the documented contract here, exactly as `solve` does,
        # rather than depending on whatever `_solvevec` happens to do.
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        with self._name_scope(name):
            rhs = ops.convert_to_tensor(rhs, name="rhs")
            self._check_input_dtype(rhs)
            # The length of `rhs` must match the rows of `A` (its columns when
            # solving with the adjoint).
            self_dim = -1 if adjoint else -2
            tensor_shape.dimension_at_index(
                self.shape,
                self_dim).assert_is_compatible_with(rhs.get_shape()[-1])

            return self._solvevec(rhs, adjoint=adjoint)
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
    """Merge an existing row count `nrows` with the row count of `value`.

    Checks that `value` has the expected number of rows (`nrows`).  When both
    static counts are known they are compared immediately; otherwise, if
    `validate` is true, a runtime assertion that the counts match is attached
    to the returned `nrows`.

    Args:
      nrows: Scalar integer `Tensor`, or `None` if no row count is known yet.
      static_nrows: `tf.Dimension` holding the static value of `nrows`, if
        known.
      value: `Tensor`, `RaggedTensor` or `StructuredTensor`.
      dtype: dtype used when computing the row count of a plain `Tensor`.
      validate: bool -- whether to add validation ops.

    Returns:
      A tuple `(nrows, static_nrows)`.

    Raises:
      ValueError: If both static row counts are known and incompatible.
    """
    value_static_nrows = tensor_shape.dimension_at_index(value.shape, 0)
    value_nrows = (array_ops.shape(value, out_type=dtype)[0]
                   if isinstance(value, ops.Tensor) else value.nrows())

    if nrows is None:
        nrows = value_nrows
    elif (static_nrows.value is not None
          and value_static_nrows.value is not None):
        if not value_static_nrows.is_compatible_with(static_nrows):
            raise ValueError('fields have incompatible nrows')
        # Verified statically, so no runtime assertion op is needed.
        nrows = value_nrows
    elif validate:
        nrows_match = check_ops.assert_equal(
            nrows, value_nrows, message='fields have incompatible nrows')
        nrows = control_flow_ops.with_dependencies([nrows_match], nrows)

    merged_static_nrows = static_nrows.merge_with(value_static_nrows)
    return nrows, merged_static_nrows
Beispiel #15
0
    def nrows(self, out_type=None, name=None):
        """Returns the number of rows in this ragged tensor.

        That is, the size of the outermost dimension of the tensor.

        Args:
          out_type: `dtype` for the returned tensor.  Defaults to
            `self.row_splits.dtype`.
          name: A name prefix for the returned tensor (optional).

        Returns:
          A scalar `Tensor` with dtype `out_type`.
        """
        out_type = (self._row_splits.dtype if out_type is None
                    else dtypes.as_dtype(out_type))

        # A cached row count short-circuits the computation.
        if self._cached_nrows is not None:
            return math_ops.cast(self._cached_nrows, out_type)

        with ops.name_scope(name, "RaggedNRows", [self]):
            # row_splits has one more entry than there are rows.
            num_splits = tensor_shape.dimension_at_index(
                self.row_splits.shape, 0)
            if num_splits.value is not None:
                return constant_op.constant(num_splits.value - 1,
                                            dtype=out_type)
            return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
  def _check_shapes(self):
    """Statically check that the operator and update shapes are compatible.

    Raises:
      ValueError: If `u` and `v`, the base operator, or the optional diagonal
        update have incompatible static shapes.
    """
    # Broadcasting u against v also verifies that they are compatible.
    uv_shape = array_ops.broadcast_static_shape(
        self.u.get_shape(), self.v.get_shape())
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    # The second-to-last dimension of u/v must match the base operator's
    # domain dimension.
    domain_dim = tensor_shape.Dimension(self.base_operator.domain_dimension)
    domain_dim.assert_is_compatible_with(uv_shape[-2])

    if self._diag_update is None:
      return

    # The diagonal update's last dimension must match that of u/v, and its
    # leading dimensions must broadcast against the batch shape.
    uv_last_dim = tensor_shape.dimension_at_index(uv_shape, -1)
    uv_last_dim.assert_is_compatible_with(self._diag_update.get_shape()[-1])
    array_ops.broadcast_static_shape(
        batch_shape, self._diag_update.get_shape()[:-1])
Beispiel #17
0
    def _check_shapes(self):
        """Statically check that the operator and update shapes are compatible.

        Raises:
          ValueError: If `u` and `v`, the base operator, or the optional
            diagonal update have incompatible static shapes.
        """
        # Broadcasting u against v also verifies that they are compatible.
        uv_shape = array_ops.broadcast_static_shape(self.u.shape, self.v.shape)
        batch_shape = array_ops.broadcast_static_shape(
            self.base_operator.batch_shape, uv_shape[:-2])

        # The second-to-last dimension of u/v must match the base operator's
        # domain dimension.
        domain_dim = tensor_shape.Dimension(
            self.base_operator.domain_dimension)
        domain_dim.assert_is_compatible_with(uv_shape[-2])

        if self._diag_update is None:
            return

        # The diagonal update's last dimension must match that of u/v, and
        # its leading dimensions must broadcast against the batch shape.
        uv_last_dim = tensor_shape.dimension_at_index(uv_shape, -1)
        uv_last_dim.assert_is_compatible_with(self._diag_update.shape[-1])
        array_ops.broadcast_static_shape(batch_shape,
                                         self._diag_update.shape[:-1])
Beispiel #18
0
def ragged_tensor_spec(shape=None,
                       dtype=dtypes.float32,
                       ragged_rank=None,
                       row_splits_dtype=dtypes.int64,
                       name=None):
    """Returns a tensor specification for a RaggedTensor.

    Returns an object which can be passed to `tf.function` (or other
    functions that expect `TensorSpec`s) to specify shape constraints
    for a `RaggedTensor` argument.

    Args:
      shape: The shape of the RaggedTensor, or `None` to allow any shape.
      dtype: Data type of values in the RaggedTensor.
      ragged_rank: Python integer, the ragged rank of the RaggedTensor
        to be described.  Defaults to `shape.ndims - 1`.  Must be
        non-negative, and when positive must be less than the rank of
        `shape` (if that rank is known).
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
        One of `tf.int32` or `tf.int64`.
      name: Optional name prefix for the `TensorSpec`s.

    Returns:
      An object describing the `flat_values` and `nested_row_splits` tensors
      that comprise the `RaggedTensor`.  When `ragged_rank` is 0 this is a
      plain `TensorSpec` for `shape`.

    Raises:
      ValueError: If `ragged_rank` is omitted and `shape` has unknown rank, or
        if `ragged_rank` is out of range for `shape`.
      TypeError: If `ragged_rank` is not an int.
    """
    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.TensorShape(shape)
    if ragged_rank is None:
        if shape.ndims is None:
            raise ValueError(
                "Must specify ragged_rank or a shape with known rank.")
        ragged_rank = shape.ndims - 1
    elif not isinstance(ragged_rank, int):
        raise TypeError("ragged_rank must be an int")

    # An out-of-range ragged_rank would make `shape[ragged_rank + 1:]` below
    # slice past the end of `shape` and silently describe a RaggedTensor
    # whose rank disagrees with `shape`, so reject it up front.  (A scalar
    # `shape` with the default ragged_rank yields -1 and is rejected here.)
    if ragged_rank < 0:
        raise ValueError(
            "ragged_rank must be non-negative; got %d for shape %s." %
            (ragged_rank, shape))
    if ragged_rank == 0:
        return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)
    if shape.ndims is not None and ragged_rank >= shape.ndims:
        raise ValueError(
            "ragged_rank must be less than the rank of shape; got %d for "
            "shape %s." % (ragged_rank, shape))

    # Flat values: one ragged (unknown-length) dimension followed by the
    # uniform inner dimensions of `shape`.
    result = tensor_spec.TensorSpec(
        tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
        dtype, name)

    # Wrap the values in the inner row-splits levels, innermost first.
    for i in range(ragged_rank - 1, 0, -1):
        splits = tensor_spec.TensorSpec([None], row_splits_dtype,
                                        "%s.row_splits_%d" %
                                        (name, i) if name else None)
        result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)

    # The outermost row splits have one more entry than there are rows.
    outer_dim = tensor_shape.dimension_at_index(shape, 0)
    splits_shape = [None if outer_dim is None else outer_dim + 1]
    splits = tensor_spec.TensorSpec(splits_shape, row_splits_dtype,
                                    "%s.row_splits_0" % name if name else None)
    result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)

    return result
    def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
        """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

        ```python
        # Make an operator acting like batch matrix A.
        # Assume A.shape = [..., M, N]
        operator = LinearOperator(...)
        operator.shape = [..., M, N]

        X = ... # shape [..., N, R], batch matrix, R > 0.

        Y = operator.matmul(X)
        Y.shape
        ==> [..., M, R]

        Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
        ```

        Args:
          x: `Tensor` with compatible shape and same `dtype` as `self`.
            See class docstring for definition of compatibility.
          adjoint: Python `bool`.  If `True`, left multiply by the adjoint:
            `A^H x`.
          adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H`
            is the hermitian transpose (transposition and complex conjugation).
          name:  A name for this `Op`.

        Returns:
          A `Tensor` with shape `[..., M, R]` and same `dtype` as `self`.

        Raises:
          ValueError: If the static shapes of `self` and `x` are incompatible.
        """
        with self._name_scope(name, values=[x]):
            x = ops.convert_to_tensor(x, name="x")
            self._check_input_dtype(x)

            # The contracted axis of `x` (its rows, or its columns when
            # `adjoint_arg`) must match the columns of `A` (its rows when
            # `adjoint`).
            self_dim = -2 if adjoint else -1
            arg_dim = -1 if adjoint_arg else -2
            tensor_shape.dimension_at_index(
                self.shape,
                self_dim).assert_is_compatible_with(x.get_shape()[arg_dim])

            return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
  def shape(self):
    """The statically known shape of this RaggedStructuredTensor.

    The result is `[nrows, row_length] + values.shape[1:]`.  `nrows` is the
    static length of the row splits minus one.  `row_length` is the constant
    value of the uniform row length, or `None` when there is no uniform row
    length or its value is not statically known.
    """
    num_rows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
    row_length = (None if self._uniform_row_length is None else
                  tensor_util.constant_value(self._uniform_row_length))
    inner_shape = self._values.shape[1:]
    outer_shape = tensor_shape.TensorShape([num_rows, row_length])
    return outer_shape.concatenate(inner_shape)
  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M, R]` and same `dtype` as `self`.

    Raises:
      ValueError: If the static shapes of `self` and `x` are incompatible.
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      self._check_input_dtype(x)

      # The contracted axis of `x` (its rows, or its columns when
      # `adjoint_arg`) must match the columns of `A` (its rows when `adjoint`).
      self_dim = -2 if adjoint else -1
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              x.get_shape()[arg_dim])

      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
def ragged_tensor_spec(shape=None, dtype=dtypes.float32,
                       ragged_rank=None, row_splits_dtype=dtypes.int64,
                       name=None):
  """Returns a tensor specification for a RaggedTensor.

  Returns an object which can be passed to `tf.function` (or other
  functions that expect `TensorSpec`s) to specify shape constraints
  for a `RaggedTensor` argument.

  Args:
    shape: The shape of the RaggedTensor, or `None` to allow any shape.
    dtype: Data type of values in the RaggedTensor.
    ragged_rank: Python integer, the ragged rank of the RaggedTensor
      to be described.  Defaults to `shape.ndims - 1`.
    row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
      One of `tf.int32` or `tf.int64`.
    name: Optional name prefix for the `TensorSpec`s.

  Returns:
    An object describing the `flat_values` and `nested_row_splits` tensors
    that comprise the `RaggedTensor`.  If `ragged_rank` is 0, a plain
    `TensorSpec` with the given `shape` is returned.

  Raises:
    ValueError: If `ragged_rank` is `None` and `shape` has unknown rank; if
      `ragged_rank` is negative (including when it is derived from a scalar
      `shape`); or if `ragged_rank` is positive and not less than the known
      rank of `shape`.
    TypeError: If `ragged_rank` is not an int.
  """
  dtype = dtypes.as_dtype(dtype)
  shape = tensor_shape.TensorShape(shape)
  if ragged_rank is None:
    if shape.ndims is None:
      raise ValueError("Must specify ragged_rank or a shape with known rank.")
    ragged_rank = shape.ndims - 1
  elif not isinstance(ragged_rank, int):
    raise TypeError("ragged_rank must be an int")
  # Previously a negative ragged_rank (e.g. derived from a scalar shape)
  # silently produced a malformed spec; reject it explicitly.
  if ragged_rank < 0:
    raise ValueError("ragged_rank must be non-negative; got ragged_rank=%d "
                     "for shape %s" % (ragged_rank, shape))
  if ragged_rank == 0:
    return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)
  # A ragged tensor of rank R has at most R - 1 ragged dimensions; otherwise
  # more row_splits would be created than `shape` has dimensions.
  if shape.ndims is not None and ragged_rank >= shape.ndims:
    raise ValueError("ragged_rank must be less than the rank of shape; got "
                     "ragged_rank=%d for shape %s" % (ragged_rank, shape))

  # Innermost component: the flat_values spec, whose outer dimension is
  # unknown followed by the dimensions below the ragged ones.
  result = tensor_spec.TensorSpec(
      tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
      dtype, name)

  # Wrap with one row_splits spec per inner ragged dimension, innermost first.
  for i in range(ragged_rank - 1, 0, -1):
    splits = tensor_spec.TensorSpec(
        [None], row_splits_dtype,
        "%s.row_splits_%d" % (name, i) if name else None)
    result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)

  # The outermost row_splits has one more entry than there are rows.
  outer_dim = tensor_shape.dimension_at_index(shape, 0)
  splits_shape = [None if outer_dim is None else outer_dim + 1]
  splits = tensor_spec.TensorSpec(
      splits_shape, row_splits_dtype,
      "%s.row_splits_0" % name if name else None)
  result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)

  return result
Beispiel #23
0
  def static_nvals(self):
    """The number of values in this partition, if statically known.

    ```python
    self.value_rowids().shape == [self.static_nvals]
    ```

    Returns:
      The number of values in this partition as an `int` (if statically known);
      or `None` (otherwise).
    """
    if self._value_rowids is None:
      # This method only consults `value_rowids` for a static value count.
      return None
    # `.value` is already `None` when the dimension is statically unknown.
    return tensor_shape.dimension_at_index(self._value_rowids.shape, 0).value
Beispiel #24
0
def _replace_ragged_with_flat_values(value, partition_lists,
                                     flat_values_nrows):
    """Replace RaggedTensors with their flat_values, and record their partitions.

    Returns a copy of `value`, with any nested `RaggedTensor`s replaced by their
    `flat_values` tensor.  Looks inside lists, tuples, and dicts; any other
    value is returned unchanged.

    Appends each `RaggedTensor`'s `RowPartition`s to `partition_lists`.

    Args:
      value: The value that should be transformed by replacing `RaggedTensors`.
      partition_lists: An output parameter used to record the row partitions
        for any `RaggedTensors` that were replaced.
      flat_values_nrows: An output parameter used to record the outer dimension
        size for each replacement `flat_values` (when known).  Contains a list of
        int; nothing is appended when the size is not statically known.

    Returns:
      A copy of `value` with nested `RaggedTensors` replaced by their
      `flat_values`.
    """

    def _convert(item):
        return _replace_ragged_with_flat_values(item, partition_lists,
                                                flat_values_nrows)

    # Leaf case: a ragged value is swapped for its flat_values tensor.
    if ragged_tensor.is_ragged(value):
        rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
        partition_lists.append(rt._nested_row_partitions)  # pylint: disable=protected-access
        flat_values = rt.flat_values
        static_nrows = tensor_shape.dimension_at_index(flat_values.shape,
                                                       0).value
        if static_nrows is not None:
            flat_values_nrows.append(static_nrows)
        return flat_values

    # Container cases: rebuild the same container kind with converted items.
    if isinstance(value, list):
        return [_convert(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_convert(item) for item in value)
    if isinstance(value, dict):
        return {key: _convert(item) for key, item in value.items()}
    return value
Beispiel #25
0
    def nrows(self, out_type=None):
        """Returns the number of rows created by this `RowPartition`.

        Args:
          out_type: `dtype` for the returned tensor.  Defaults to `self.dtype`.

        Returns:
          scalar integer Tensor
        """
        out_type = self.dtype if out_type is None else dtypes.as_dtype(out_type)
        if self._nrows is not None:
            # A stored row count takes precedence over row_splits.
            return math_ops.cast(self._nrows, out_type)
        splits_dim = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
        static_len = splits_dim.value
        if static_len is not None:
            # row_splits has one more entry than there are rows.
            return constant_op.constant(static_len - 1, dtype=out_type)
        dynamic_len = array_ops.shape(self._row_splits, out_type=out_type)[0]
        return dynamic_len - 1
Beispiel #26
0
    def shape(self):
        """The statically known shape of this ragged tensor.

        The outer dimension is one less than the static length of
        `row_splits`, the second dimension is `None`, and the rest are the
        inner dimensions of `values`.

        Returns:
          A `TensorShape` containing the statically known shape of this ragged
          tensor.  Ragged dimensions have a size of `None`.

        Examples:

          ```python
          >>> ragged.constant([[0], [1, 2]]).shape
          TensorShape([Dimension(2), Dimension(None)])

          >>> ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
          TensorShape([Dimension(2), Dimension(None), Dimension(2)])
          ```
        """
        splits_len = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
        outer_and_ragged = tensor_shape.TensorShape([splits_len - 1, None])
        return outer_and_ragged.concatenate(self._values.shape[1:])
Beispiel #27
0
  def shape(self):
    """Returns the statically known `TensorShape` of this ragged tensor.

    Built from one less than the static length of `row_splits`, followed by
    an unknown (ragged) dimension and the inner dimensions of `values`.

    Returns:
      A `TensorShape` containing the statically known shape of this ragged
      tensor.  Ragged dimensions have a size of `None`.

    Examples:

      ```python
      >>> ragged.constant([[0], [1, 2]]).shape
      TensorShape([Dimension(2), Dimension(None)])

      >>> ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
      TensorShape([Dimension(2), Dimension(None), Dimension(2)])
      ```
    """
    row_count = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
    ragged_dims = [row_count, None]
    return tensor_shape.TensorShape(ragged_dims).concatenate(
        self._values.shape[1:])
Beispiel #28
0
  def nrows(self, out_type=None, name=None):
    """Returns the number of rows created by this `RowPartition`.

    A cached row count is returned (cast to `out_type`) when available.
    Otherwise the count is one less than the length of `row_splits`, as a
    constant when that length is static and as a dynamic shape op otherwise.

    Args:
      out_type: `dtype` for the returned tensor.  Defaults to
        `self.dtype`.
      name: A name prefix for the returned tensor (optional).

    Returns:
      scalar integer Tensor
    """
    out_type = self.dtype if out_type is None else dtypes.as_dtype(out_type)
    cached = self._cached_nrows
    if cached is not None:
      return math_ops.cast(cached, out_type)
    with ops.name_scope(name, "RaggedNRows", [self]):
      splits = self.row_splits
      static_count = tensor_shape.dimension_at_index(splits.shape, 0).value
      if static_count is not None:
        return constant_op.constant(static_count - 1, dtype=out_type)
      return array_ops.shape(splits, out_type=out_type)[0] - 1
def _ragged_getitem(rt_input, key_list):
    """Helper for indexing and slicing ragged tensors with __getitem__().

    Extracts the specified piece of the `rt_input`.  See
    `RaggedTensor.__getitem__` for examples and restrictions.

    Args:
      rt_input: The `RaggedTensor` from which a piece should be returned.
      key_list: The list of keys specifying which piece to return. Each key
        corresponds with a separate dimension.

    Returns:
      The indicated piece of rt_input.

    Raises:
      ValueError: If `key_list` is not supported.
      TypeError: If any keys in `key_list` have an unsupported type.
      IndexError: When executing eagerly, if the outermost key is an integer
        row index outside `[-nrows, nrows)`.
    """
    if not key_list:
        return rt_input
    row_key = key_list[0]
    inner_keys = key_list[1:]

    if row_key is Ellipsis:
        expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims)
        return _ragged_getitem(rt_input, expanded_key_list)

    # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor
    # that puts all values in a single row.
    if row_key is array_ops.newaxis:
        inner_rt = _ragged_getitem(rt_input, inner_keys)
        nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
        if nsplits.value is not None:
            nsplits = nsplits.value
        else:
            nsplits = array_ops.shape(inner_rt.row_splits,
                                      out_type=inner_rt.row_splits.dtype)[0]
        return ragged_tensor.RaggedTensor.from_uniform_row_length(
            inner_rt, nsplits - 1, nrows=1, validate=False)

    # Slicing a range of rows: first slice the outer dimension, and then
    # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
    if isinstance(row_key, slice):
        sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
        if rt_input.uniform_row_length is not None:
            # If the inner dimension has uniform_row_length, then preserve it (by
            # re-wrapping the values in a new RaggedTensor).  Note that the row
            # length won't have changed, since we're slicing a range of rows (and not
            # slicing the rows themselves).
            sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
                sliced_rt_input.values,
                rt_input.uniform_row_length,
                nrows=sliced_rt_input.nrows())
        return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)

    # Indexing a single row: slice values to get the indicated row, and then
    # use a recursive call to __getitem__ to handle the inner keys.
    else:
        starts = rt_input.row_splits[:-1]
        limits = rt_input.row_splits[1:]
        if context.executing_eagerly():
            # In python, __getitem__ should throw IndexError for out of bound
            # indices. This will allow iteration run correctly as python will
            # translate IndexError into StopIteration for next()/__next__().
            # Below is an example:
            #    import tensorflow as tf
            #    r = tf.ragged.constant([[1., 2.], [3., 4., 5.], [6.]])
            #    for elem in r:
            #      print(elem)
            # In non eager mode, the exception is thrown when session runs
            # so we don't know if out of bound happens before.
            # In eager mode, however, it is possible to find out when to
            # throw out of bound IndexError.
            # Both bounds are checked so that negative keys behave like Python
            # sequence indices (valid range: [-nrows, nrows)).  A TypeError or
            # ValueError (e.g. row_key is not convertible to an int) is ignored
            # here, as it will be processed later anyway.
            try:
                num_rows = len(starts)
                if not -num_rows <= int(row_key) < num_rows:
                    raise IndexError(
                        "Row key {} out of bounds".format(row_key))
            except (TypeError, ValueError):
                pass
        row = rt_input.values[starts[row_key]:limits[row_key]]
        return row.__getitem__(inner_keys)
Beispiel #30
0
 def nrows_as_dimension(self):
   """Returns the first dimension of the shape as a `tf.Dimension`.

   Computed as the static length of `row_splits` minus one, using `Dimension`
   arithmetic (so an unknown length stays unknown).
   """
   num_splits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
   return num_splits - 1
Beispiel #31
0
 def _component_specs(self):
     """Returns the `TensorSpec` describing the `row_splits` component.

     The spec is one-dimensional with size `nrows + 1` and dtype `self._dtype`.
     NOTE(review): assumes `self._nrows` is a `TensorShape` whose dimension 0
     holds the (possibly unknown) row count -- confirm against the class.
     """
     outer = tensor_shape.dimension_at_index(self._nrows, 0)
     splits_len = outer + 1
     return tensor_spec.TensorSpec(
         tensor_shape.TensorShape([splits_len]), self._dtype)
def _ragged_getitem_inner_dimensions(rt_input, key_list):
    """Retrieve inner dimensions, keeping outermost dimension unchanged.

    Each key in `key_list` applies to the dimensions of `rt_input` below the
    outermost one; the outermost (row) dimension is always kept whole.

    Args:
      rt_input: The `RaggedTensor` or `Tensor` from which a piece should be
        extracted.
      key_list: The __getitem__ keys for slicing the inner dimensions.  It is
        concatenated onto a Python list below, so it is expected to be a list.

    Returns:
      A `RaggedTensor`, or a `Tensor` when `rt_input` is a `Tensor`.

    Raises:
      ValueError: If key_list is not supported (e.g. an integer index into a
        ragged inner dimension).
      TypeError: If a slice's start or stop is not an integer, `Tensor`, or
        None.
    """
    if not key_list:
        return rt_input

    # Dense input: prepend a full slice for the (kept) outer dimension and let
    # Tensor.__getitem__ handle the rest.
    if isinstance(rt_input, ops.Tensor):
        return rt_input.__getitem__([slice(None, None, None)] + key_list)

    column_key = key_list[0]
    if column_key is Ellipsis:
        # The inner dimensions are those of rt_input.values, so expand the
        # ellipsis against that rank.
        expanded_key_list = _expand_ellipsis(key_list,
                                             rt_input.values.shape.ndims)
        return _ragged_getitem_inner_dimensions(rt_input, expanded_key_list)

    # Adding a new axis to a ragged inner dimension: recursively get the inner
    # dimensions of rt_input with key_list[1:], and then wrap the result in a
    # RaggedTensor that puts each value in its own row.
    # NOTE(review): this reads inner_rt.row_splits, so it assumes the recursive
    # result is ragged -- confirm for inputs where it could be a dense Tensor.
    if column_key is array_ops.newaxis:
        inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
        nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
        # Prefer a static Python int for the split count; otherwise fall back
        # to a dynamic shape op.
        if nsplits.value is not None:
            nsplits = nsplits.value
        else:
            nsplits = array_ops.shape(inner_rt.row_splits,
                                      out_type=inner_rt.row_splits.dtype)[0]
        return ragged_tensor.RaggedTensor.from_uniform_row_length(
            inner_rt, 1, nrows=nsplits - 1, validate=False)

    # Slicing a range of columns in a ragged inner dimension.  We use a
    # recursive call to process the values, and then assemble a RaggedTensor
    # with those values.
    if isinstance(column_key, slice):
        if (column_key.start is None and column_key.stop is None
                and column_key.step is None):
            # Trivial slice: recursively process all values, & splits is unchanged.
            return rt_input.with_values(
                _ragged_getitem_inner_dimensions(rt_input.values,
                                                 key_list[1:]))
        else:
            # Only start/stop are type-checked here; step is passed through
            # to _build_ragged_tensor_from_value_ranges unchecked.
            if not (isinstance(column_key.start, (ops.Tensor, int, type(None)))
                    and isinstance(column_key.stop,
                                   (ops.Tensor, int, type(None)))):
                raise TypeError("slice offsets must be integers or None")

            # Nontrivial slice: use ragged_gather to extract the indicated slice as
            # a new RaggedTensor (inner_rt), and then recursively process its values.
            # starts/limits are the per-row [start, limit) offsets into
            # rt_input.values.
            starts = rt_input.row_splits[:-1]
            limits = rt_input.row_splits[1:]
            step = 1 if column_key.step is None else column_key.step
            # Clamping bounds for gather positions within each row.  For a
            # negative step, both bounds shift down by one.  (_if_ge_zero
            # presumably picks its first branch when the value is >= 0.)
            lower_bound = _if_ge_zero(step, lambda: starts, lambda: starts - 1)
            upper_bound = _if_ge_zero(step, lambda: limits, lambda: limits - 1)
            # inner_rt_starts[i] = index to start gathering for row i.
            # Non-negative offsets count from the row start; negative offsets
            # count from the row limit; results are clamped to the row.
            if column_key.start is None:
                inner_rt_starts = _if_ge_zero(step, lambda: starts,
                                              lambda: limits - 1)
            else:
                start_offset = math_ops.cast(column_key.start, starts.dtype)
                inner_rt_starts = _if_ge_zero(
                    column_key.start, lambda: math_ops.minimum(
                        starts + start_offset, upper_bound), lambda: math_ops.
                    maximum(limits + start_offset, lower_bound))
            # inner_rt_limits[i] = index to stop gathering for row i.
            if column_key.stop is None:
                inner_rt_limits = _if_ge_zero(step, lambda: limits,
                                              lambda: starts - 1)
            else:
                stop_offset = math_ops.cast(column_key.stop, starts.dtype)
                inner_rt_limits = _if_ge_zero(
                    column_key.stop, lambda: math_ops.minimum(
                        starts + stop_offset, upper_bound), lambda: math_ops.
                    maximum(limits + stop_offset, lower_bound))
            inner_rt = _build_ragged_tensor_from_value_ranges(
                inner_rt_starts, inner_rt_limits, column_key.step,
                rt_input.values)
            # If the row dimension is uniform, then calculate the new
            # uniform_row_length, and rebuild inner_rt using that uniform_row_lengths.
            if rt_input.uniform_row_length is not None:
                new_row_length = _slice_length(rt_input.uniform_row_length,
                                               column_key)
                inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
                    inner_rt.values, new_row_length, rt_input.nrows())
            return inner_rt.with_values(
                _ragged_getitem_inner_dimensions(inner_rt.values,
                                                 key_list[1:]))

    # Indexing a single column in a ragged inner dimension: raise an Exception.
    # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
    # into a ragged inner dimension is problematic.
    if rt_input.uniform_row_length is None:
        raise ValueError("Cannot index into an inner ragged dimension.")

    # Indexing a single column in a uniform inner dimension: check that the
    # given index is in-bounds, and then use a strided slice over rt_input.values
    # to take the indicated element from each row.
    row_length = rt_input.uniform_row_length
    column_key = math_ops.cast(column_key, row_length.dtype)
    oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
    # Runtime checks that -row_length <= column_key < row_length.
    oob_checks = [
        check_ops.assert_greater_equal(column_key,
                                       -row_length,
                                       message=oob_err_msg),
        check_ops.assert_less(column_key, row_length, message=oob_err_msg),
    ]
    with ops.control_dependencies(oob_checks):
        # Normalize a negative column index to its non-negative equivalent.
        offset = _if_ge_zero(column_key, lambda: column_key,
                             lambda: row_length + column_key)
        # With every row exactly row_length long, stepping by row_length from
        # offset picks the offset-th element of each row.
        sliced_rt = rt_input.values[offset::row_length]
        return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:])
Beispiel #33
0
    def __init__(
        self,
        input_shape,
        filter_shape,
        padding,
        strides=None,
        dilation_rate=None,
        name=None,
        data_format=None,
    ):
        """Helper function for convolution.

        Validates the static ranks of `input_shape` and `filter_shape`, checks
        that the input channel count is divisible by the corresponding filter
        dimension, and builds an `nn_ops._WithSpaceToBatch` op around
        `self._build_op`, stored as `self.conv_op`.

        Args:
          input_shape: Static shape of the convolution input.
          filter_shape: Static shape of the filter.
          padding: Padding argument, stored and forwarded to
            `_WithSpaceToBatch`.
          strides: Optional strides; passed together with `dilation_rate`
            through `nn_ops._get_strides_and_dilation_rate`.
          dilation_rate: Optional dilation rate (see `strides`).
          name: Optional name, stored as `self.name`.
          data_format: Optional data format string.  Values starting with
            "NC" are treated as channels-first; anything else (including
            `None`) as channels-last.

        Raises:
          ValueError: If neither rank is known, if either shape does not have
            the expected rank, or if the input channel count is not divisible
            by the filter's corresponding dimension.
        """
        total_rank = filter_shape.ndims
        if total_rank is None:
            total_rank = input_shape.ndims
        if total_rank is None:
            raise ValueError("rank of input or filter must be known")
        # Two of the dimensions are batch and channels; the rest are spatial.
        num_spatial_dims = total_rank - 2

        try:
            input_shape.with_rank(total_rank)
        except ValueError:
            raise ValueError("input tensor must have rank %d" % total_rank)
        try:
            filter_shape.with_rank(total_rank)
        except ValueError:
            raise ValueError("filter tensor must have rank %d" % total_rank)

        channels_first = (data_format is not None
                          and data_format.startswith("NC"))
        if channels_first:
            input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
            spatial_dims = range(2, num_spatial_dims + 2)
        else:
            input_channels_dim = tensor_shape.dimension_at_index(
                input_shape, num_spatial_dims + 1)
            spatial_dims = range(1, num_spatial_dims + 1)

        # The filter dimension just after its spatial dimensions must divide
        # the number of input channels (unknown sizes are treated as compatible).
        filter_dim = tensor_shape.dimension_at_index(
            filter_shape, num_spatial_dims)
        remainder = input_channels_dim % filter_dim
        if not remainder.is_compatible_with(0):
            raise ValueError(
                "number of input channels is not divisible by corresponding "
                "dimension of filter, {} % {} != 0".format(
                    input_channels_dim, filter_dim))

        strides, dilation_rate = nn_ops._get_strides_and_dilation_rate(
            num_spatial_dims, strides, dilation_rate)

        self.input_shape = input_shape
        self.filter_shape = filter_shape
        self.data_format = data_format
        self.strides = strides
        self.padding = padding
        self.name = name
        self.dilation_rate = dilation_rate
        self.conv_op = nn_ops._WithSpaceToBatch(
            input_shape,
            dilation_rate=dilation_rate,
            padding=padding,
            build_op=self._build_op,
            filter_shape=filter_shape,
            spatial_dims=spatial_dims,
            data_format=data_format,
        )
    def from_fields(cls,
                    fields,
                    shape=(),
                    nrows=None,
                    row_partitions=None,
                    validate=False):
        """Creates a `StructuredTensor` from a dictionary of fields.

    Args:
      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
        `StructuredTensor`, providing the values for individual fields in each
        structure.  If `shape.rank > 0`, then every tensor in `fields` must have
        the same shape in the first `shape.rank` dimensions; and that shape must
        be compatible with `shape`; and
        `result[i1...iN][key] = fields[key][i1...iN]` (where `N==shape.rank`).
      shape: A `TensorShape`: static information about the shape of the
        `StructuredTensor`.  Must have a known `rank`.  Defaults to scalar
        shape (i.e. `rank=0`).
      nrows: scalar integer tensor containing the number of rows in this
        `StructuredTensor`.  Should only be specified if `shape.rank > 0`.
        Default value is inferred from the `fields` values.  If `fields` is
        empty, then this must be specified.
      row_partitions: A list of `RowPartition`s describing the (possibly ragged)
        shape of this `StructuredTensor`.  Should only be specified if
        `shape.rank > 1`.  Default value is inferred from the `fields` values.
        If `fields` is empty, then this must be specified.
      validate: If true, then add runtime validation ops that check that the
        field values all have compatible shapes in the outer `shape.rank`
        dimensions.

    Returns:
      A `StructuredTensor`.

    Raises:
      TypeError: If `fields` is not a dict, or if a key in `fields` is not a
        `str`.
      ValueError: If `shape` has unknown rank; if `nrows` or `row_partitions`
        are given when `shape.rank` does not allow them; if
        `len(row_partitions) != shape.rank - 1`; if a field name is not
        allowed; if a field's leading shape is incompatible with `shape`; or
        if `nrows`/`row_partitions` cannot be inferred (e.g. `fields` is empty
        and the needed dimensions of `shape` are not known).

    Examples:

      >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
      <StructuredTensor(fields={
                            x: tf.Tensor(1, shape=(), dtype=int32),
                            y: tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
                        shape=())>

      >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]},
      ...                              shape=[2])
      <StructuredTensor(fields={
                            bar: tf.Tensor([3 4], shape=(2,), dtype=int32),
                            foo: tf.Tensor([1 2], shape=(2,), dtype=int32)},
                        shape=(2,))>

    """
        # Validate the static arguments before touching any field values.
        shape = tensor_shape.as_shape(shape)
        rank = shape.rank
        if rank is None:
            raise ValueError("StructuredTensor's shape must have known rank.")
        if not isinstance(fields, dict):
            raise TypeError('fields must be a dictionary, got %s' %
                            type(fields).__name__)
        if rank < 2 and row_partitions:
            raise ValueError(
                'row_partitions must be None or [] if shape.rank<2')
        if rank == 0 and nrows is not None:
            raise ValueError('nrows must be None if shape.rank==0')
        if row_partitions is not None:
            row_partitions = tuple(row_partitions)
            if len(row_partitions) != max(0, rank - 1):
                raise ValueError('len(row_partitions) must be shape.rank-1')
        elif rank < 2:
            # Rank 0 and 1 structures have no row partitions.
            row_partitions = ()

        fields = dict(fields)  # Make a private copy.
        with ops.name_scope(None, 'StructuredTensor', fields.values()):

            # Validate keys and convert field values to tensors.
            for key, value in fields.items():
                if not isinstance(key, str):
                    raise TypeError('Unexpected type for key in `fields`: %r' %
                                    key)
                if not _FIELD_NAME_RE.match(key):
                    raise ValueError(
                        'Field name %r is not currently allowed.' % key)
                fields[key] = _convert_to_structured_field_value(value)

            # Determine dtype for row_partitions and nrows.
            shape_dtype = _find_shape_dtype(fields, nrows, row_partitions)
            if nrows is not None:
                nrows = ops.convert_to_tensor(nrows, shape_dtype)

            # Get the static TensorShape for this StructuredTensor.
            # Each field's leading `rank` dimensions are checked against, then
            # merged into, the running `shape`.
            if rank > 0:
                for key, value in fields.items():
                    if not shape.is_compatible_with(value.shape[:rank]):
                        raise ValueError(
                            'Field {} has shape {}, which is incompatible '
                            'with the shape that was specified or inferred '
                            'from other fields: {}'.format(
                                key, value.shape[:rank], shape))
                    shape = shape.merge_with(value.shape[:rank])

            if rank == 1:
                # Find a consistent value for `nrows`.
                static_nrows = tensor_shape.dimension_at_index(shape, 0)
                for value in fields.values():
                    nrows, static_nrows = _merge_nrows(nrows, static_nrows,
                                                       value, shape_dtype,
                                                       validate)
                if nrows is None:
                    # No field and no argument supplied nrows; fall back to
                    # the static outer dimension, if known.
                    if static_nrows.value is None:
                        raise ValueError('nrows must be specified if rank==1 '
                                         'and `fields` is empty.')
                    else:
                        nrows = constant_op.constant(static_nrows.value,
                                                     shape_dtype)

            if rank > 1:
                # Find a consistent list of RowPartitions.
                for value in fields.values():
                    row_partitions = _merge_row_partitions(
                        row_partitions, value, rank, shape_dtype, validate)
                if row_partitions is None:
                    # No field supplied partitions; build uniform ones from a
                    # fully defined static shape, if possible.
                    if not shape.is_fully_defined():
                        raise ValueError(
                            'row_partitions must be specified if rank>1 '
                            'and `fields` is empty.')
                    else:
                        row_partitions = _row_partitions_for_uniform_shape(
                            np.array(shape.as_list(),
                                     dtype=shape_dtype.as_numpy_dtype),
                            shape.rank)
                # Internal invariant (note: `assert` is stripped under -O).
                assert len(row_partitions) == rank - 1
                # For rank > 1, nrows always comes from the outermost partition.
                nrows = row_partitions[0].nrows()
                # Update all field values to use the shared RowPartition objects.
                fields = dict([(k, _replace_row_partitions(v, row_partitions))
                               for (k, v) in fields.items()])

        return cls(fields,
                   shape,
                   nrows,
                   row_partitions,
                   internal=_structured_tensor_factory_key)