Example #1
0
    def __init__(self,
                 block_operators,
                 is_non_singular=None,
                 is_self_adjoint=None,
                 is_positive_definite=None,
                 is_square=None,
                 name=None):
        """Build a block-diagonal operator out of `block_operators`.

        Args:
          block_operators: Non-empty iterable of operators that all share the
            same `dtype`.
          is_non_singular: Hint passed to the base class. Set to `True`
            automatically when every block reports itself as non-singular.
          is_self_adjoint: Hint passed through to the base class.
          is_positive_definite: Hint passed through to the base class.
          is_square: Hint passed through to the base class.
          name: Name for this operator. Defaults to
            `"LinearOperatorBlockDiag_of_"` followed by the block names
            joined with `_`.

        Raises:
          ValueError: If `block_operators` is empty, or if `is_non_singular`
            is `False` although every block is non-singular.
          TypeError: If the blocks do not all have the same `dtype`.
        """
        # Materialize the blocks and reject an empty collection.
        check_ops.assert_proper_iterable(block_operators)
        blocks = list(block_operators)
        if len(blocks) == 0:
            raise ValueError(
                "Expected a non-empty list of block operators. Found: %s" %
                blocks)
        self._block_operators = blocks

        # Every block must agree with the dtype of the first one.
        dtype = blocks[0].dtype
        if any(block.dtype != dtype for block in blocks):
            described = [str((block.name, block.dtype)) for block in blocks]
            raise TypeError(
                "Expected all operators to have the same dtype. Found %s" %
                "   ".join(described))

        # Derive the non-singular hint from the blocks and reject a
        # contradicting user hint.
        every_block_non_singular = all(
            block.is_non_singular for block in blocks)
        if every_block_non_singular and is_non_singular is False:
            raise ValueError(
                "The assembly of non-singular block operators is always"
                " non-singular.")
        if every_block_non_singular:
            is_non_singular = True

        # TODO: Check other hints.

        # Collect the graph parents of every block, in block order.
        graph_parents = [
            parent for block in blocks for parent in block.graph_parents
        ]

        if name is None:
            name = "LinearOperatorBlockDiag_of_" + "_".join(
                block.name for block in blocks)

        with ops.name_scope(name, values=graph_parents):
            super(LinearOperatorBlockDiag, self).__init__(
                dtype=dtype,
                graph_parents=graph_parents,
                is_non_singular=is_non_singular,
                is_self_adjoint=is_self_adjoint,
                is_positive_definite=is_positive_definite,
                is_square=is_square,
                name=name)
 def test_list_does_not_raise(self):
   """A plain list of tensors is accepted by `assert_proper_iterable`."""
   first = constant_op.constant([11, 22])
   second = constant_op.constant([1, 2])
   check_ops.assert_proper_iterable([first, second])
 def test_non_iterable_object_raises(self):
   """An integer must be rejected with a `TypeError` saying "to be iterable"."""
   with self.assertRaisesRegexp(TypeError, "to be iterable"):
     check_ops.assert_proper_iterable(1234)
 def test_single_string_raises(self):
   """A lone string must be rejected with a `TypeError` mentioning "proper"."""
   with self.assertRaisesRegexp(TypeError, "proper"):
     check_ops.assert_proper_iterable("hello")
 def test_single_ndarray_raises(self):
   """A lone ndarray must be rejected with a `TypeError` mentioning "proper"."""
   single_ndarray = np.array([1, 2, 3])
   with self.assertRaisesRegexp(TypeError, "proper"):
     check_ops.assert_proper_iterable(single_ndarray)
 def test_single_sparse_tensor_raises(self):
   """A lone `SparseTensor` must be rejected with a "proper" `TypeError`."""
   sparse = sparse_tensor.SparseTensor(
       indices=[[0, 0], [1, 2]],
       values=[1, 2],
       dense_shape=[3, 4])
   with self.assertRaisesRegexp(TypeError, "proper"):
     check_ops.assert_proper_iterable(sparse)
  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Create a `LinearOperatorKronecker` from its Kronecker factors.

    The operator is built from a list of factors `[op_1,...,op_J]`.

    Args:
      operators:  Non-empty iterable of `LinearOperator` objects that all have
        the same `dtype` and composable shape; these are the Kronecker
        factors.
      is_non_singular:  Expect that this operator is non-singular.  Set to
        `True` automatically when every factor is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Set to `True` automatically when every factor is
        self-adjoint.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Self-adjointness is not required.  Set to `True`
        automatically when every factor is positive definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Defaults to the factor names
        joined with `_x_`.

    Raises:
      TypeError:  If the factors do not all have the same `dtype`.
      ValueError:  If `operators` is empty, or if a hint is explicitly `False`
        although every factor has the corresponding property.
    """
    # Materialize the factors and reject an empty collection.
    check_ops.assert_proper_iterable(operators)
    factors = list(operators)
    if len(factors) == 0:
      raise ValueError(
          "Expected a list of >=1 operators. Found: %s" % factors)
    self._operators = factors

    # Every factor must agree with the dtype of the first one.
    dtype = factors[0].dtype
    if any(factor.dtype != dtype for factor in factors):
      described = [str((factor.name, factor.dtype)) for factor in factors]
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(described))

    # A Kronecker product is invertible if and only if all factors are
    # invertible.
    factors_non_singular = all(factor.is_non_singular for factor in factors)
    if factors_non_singular and is_non_singular is False:
      raise ValueError(
          "The Kronecker product of non-singular operators is always "
          "non-singular.")
    if factors_non_singular:
      is_non_singular = True

    factors_self_adjoint = all(factor.is_self_adjoint for factor in factors)
    if factors_self_adjoint and is_self_adjoint is False:
      raise ValueError(
          "The Kronecker product of self-adjoint operators is always "
          "self-adjoint.")
    if factors_self_adjoint:
      is_self_adjoint = True

    # The eigenvalues of a Kronecker product are the products of the
    # eigenvalues of the factors.
    factors_positive_definite = all(
        factor.is_positive_definite for factor in factors)
    if factors_positive_definite and is_positive_definite is False:
      raise ValueError("The Kronecker product of positive-definite operators "
                       "is always positive-definite.")
    if factors_positive_definite:
      is_positive_definite = True

    # Collect the graph parents of every factor, in factor order.
    graph_parents = [
        parent for factor in factors for parent in factor.graph_parents
    ]

    if name is None:
      name = "_x_".join(factor.name for factor in factors)

    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorKronecker, self).__init__(
          dtype=dtype,
          graph_parents=graph_parents,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)
def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
    """Efficiently add one or more linear operators.

    Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list
    of operators `[B1, B2,...]` such that

    ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

    The operators `Bk` result by adding some of the `Ak`, as allowed by
    `addition_tiers`.

    Example of efficient adding of diagonal operators.

    ```python
    A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
    A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

    # Use two tiers, the first contains an Adder that returns Diag.  Since
    # both A1 and A2 are Diag, they can use this Adder.  The second tier will
    # not be used.
    addition_tiers = [
        [_AddAndReturnDiag()],
        [_AddAndReturnMatrix()]]
    B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

    len(B_list)
    ==> 1

    B_list[0].__class__.__name__
    ==> 'LinearOperatorDiag'

    B_list[0].to_dense()
    ==> [[3., 0.],
         [0., 3.]]

    B_list[0].name
    ==> 'Add/A1__A2/'
    ```

    Args:
      operators:  Iterable of `LinearOperator` objects with same `dtype`,
        domain and range dimensions, and broadcastable batch shapes.  Any
        proper iterable is accepted, including one-shot iterables such as
        generators.
      operator_name:  String name for returned `LinearOperator`.  Defaults to
        concatenation of "Add/A__B/" that indicates the order of addition
        steps.
      addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where
        `tier_i` is a list of `Adder` objects.  This function attempts to do
        all additions in tier `i` before trying tier `i + 1`.  Defaults to
        `_DEFAULT_ADDITION_TIERS`.
      name:  A name for this `Op`.  Defaults to `add_operators`.

    Returns:
      List of `LinearOperator` objects left over once no further additions
        are possible.  Class and order of addition may change as new (and
        better) addition strategies emerge.

    Raises:
      ValueError:  If `operators` argument is empty.
      ValueError:  If shapes are incompatible.
      TypeError:  If `operators` contains anything that is not a
        `LinearOperator`.
    """
    # Default setting
    if addition_tiers is None:
        addition_tiers = _DEFAULT_ADDITION_TIERS

    # Argument checking.
    check_ops.assert_proper_iterable(operators)
    # Materialize before reversing: `reversed()` only works on sequences, so
    # applying it directly would reject generators and other one-shot
    # iterables.  The list is reversed because the loop below consumes it with
    # `pop()`, which takes from the end; this way the first operator supplied
    # is the first one processed.
    operators = list(operators)
    operators.reverse()
    if not operators:
        raise ValueError(
            f"Argument `operators` must contain at least one operator. "
            f"Received: {operators}.")
    if not all(
            isinstance(op, linear_operator.LinearOperator)
            for op in operators):
        raise TypeError(
            f"Argument `operators` must contain only LinearOperator instances. "
            f"Received: {operators}.")
    _static_check_for_same_dimensions(operators)
    _static_check_for_broadcastable_batch_shape(operators)

    graph_parents = []
    for operator in operators:
        graph_parents.extend(operator.graph_parents)

    with ops.name_scope(name or "add_operators", values=graph_parents):

        # Additions done in one of the tiers.  Try tier 0, 1,...
        ops_to_try_at_next_tier = list(operators)
        for tier in addition_tiers:
            ops_to_try_at_this_tier = ops_to_try_at_next_tier
            ops_to_try_at_next_tier = []
            while ops_to_try_at_this_tier:
                op1 = ops_to_try_at_this_tier.pop()
                op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier,
                                                  tier)
                if op2 is not None:
                    # Will try to add the result of this again at this same tier.
                    new_operator = adder.add(op1, op2, operator_name)
                    ops_to_try_at_this_tier.append(new_operator)
                else:
                    # Nothing in this tier can absorb op1; defer it to the
                    # next tier (or to the final result).
                    ops_to_try_at_next_tier.append(op1)

        return ops_to_try_at_next_tier
    def __init__(self,
                 operators,
                 is_non_singular=None,
                 is_self_adjoint=None,
                 is_positive_definite=None,
                 is_square=None,
                 name="LinearOperatorBlockLowerTriangular"):
        r"""Initialize a `LinearOperatorBlockLowerTriangular`.

        `LinearOperatorBlockLowerTriangular` is initialized with a list of
        lists of operators `[[op_0], [op_1, op_2], [op_3, op_4, op_5],...]`.

        Args:
          operators:  Iterable of iterables of `LinearOperator` objects, each
            with the same `dtype`. Each element of `operators` corresponds to
            a row-partition, in top-to-bottom order. The operators in each
            row-partition are filled in left-to-right. For example,
            `operators = [[op_0], [op_1, op_2], [op_3, op_4, op_5]]` creates a
            `LinearOperatorBlockLowerTriangular` with full block structure
            `[[op_0, 0, 0], [op_1, op_2, 0], [op_3, op_4, op_5]]`. The number
            of operators in the `i`th row must be equal to `i`, such that each
            operator falls on or below the diagonal of the blockwise
            structure. `LinearOperator`s that fall on the diagonal (the last
            elements of each row) must be square. The other `LinearOperator`s
            must have domain dimension equal to the domain dimension of the
            `LinearOperator`s in the same column-partition, and range
            dimension equal to the range dimension of the `LinearOperator`s in
            the same row-partition.
          is_non_singular:  Expect that this operator is non-singular.
          is_self_adjoint:  Expect that this operator is equal to its
            hermitian transpose.
          is_positive_definite:  Expect that this operator is positive
            definite, meaning the quadratic form `x^H A x` has positive real
            part for all nonzero `x`.  Note that we do not require the
            operator to be self-adjoint to be positive-definite.  See:
            https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
          is_square:  Expect that this operator acts like square [batch]
            matrices. This will raise a `ValueError` if set to `False`.
          name: A name for this `LinearOperator`.

        Raises:
          TypeError:  If all operators do not have the same `dtype`.
          ValueError:  If `operators` is empty, contains an empty row or an
            otherwise erroneous number of elements, or contains operators with
            incompatible shapes.
        """
        parameters = dict(operators=operators,
                          is_non_singular=is_non_singular,
                          is_self_adjoint=is_self_adjoint,
                          is_positive_definite=is_positive_definite,
                          is_square=is_square,
                          name=name)

        # Validate operators.
        check_ops.assert_proper_iterable(operators)
        # Walk the outer iterable exactly once.  Iterating it a second time
        # would see nothing if the caller passed a one-shot iterable (e.g. a
        # generator), which would then be misreported as an empty argument.
        rows = list(operators)
        for row in rows:
            check_ops.assert_proper_iterable(row)
        operators = [list(row) for row in rows]

        if not operators:
            raise ValueError(
                f"Argument `operators` must be a list of >=1 operators. "
                f"Received: {operators}.")
        # An empty row has no diagonal block; report it as the documented
        # `ValueError` instead of failing with an `IndexError` on `row[-1]`.
        if any(len(row) == 0 for row in operators):
            raise ValueError(
                f"Argument `operators` must not contain empty rows. "
                f"Received: {operators}.")
        self._operators = operators
        # The last operator of each row sits on the block diagonal.
        self._diagonal_operators = [row[-1] for row in operators]

        dtype = operators[0][0].dtype
        self._validate_dtype(dtype)
        is_non_singular = self._validate_non_singular(is_non_singular)
        self._validate_num_operators()
        self._validate_operator_dimensions()
        is_square = self._validate_square(is_square)
        with ops.name_scope(name):
            super(LinearOperatorBlockLowerTriangular,
                  self).__init__(dtype=dtype,
                                 is_non_singular=is_non_singular,
                                 is_self_adjoint=is_self_adjoint,
                                 is_positive_definite=is_positive_definite,
                                 is_square=is_square,
                                 parameters=parameters,
                                 name=name)
def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.  Any proper
      iterable is accepted, including one-shot iterables such as generators.
    operator_name:  String name for returned `LinearOperator`.  Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects.  This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.  Defaults to
      `_DEFAULT_ADDITION_TIERS`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    List of `LinearOperator` objects left over once no further additions are
      possible.  Class and order of addition may change as new (and better)
      addition strategies emerge.

  Raises:
    ValueError:  If `operators` argument is empty.
    ValueError:  If shapes are incompatible.
    TypeError:  If `operators` contains anything that is not a
      `LinearOperator`.
  """
  # Default setting
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  # Materialize before reversing: `reversed()` only works on sequences, so
  # applying it directly would reject generators and other one-shot iterables.
  # The list is reversed because the loop below consumes it with `pop()`,
  # which takes from the end; this way the first operator supplied is the
  # first one processed.
  operators = list(operators)
  operators.reverse()
  if not operators:
    raise ValueError(
        "Argument 'operators' must contain at least one operator.  "
        "Found: %s" % operators)
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        "Argument 'operators' must contain only LinearOperator instances.  "
        "Found: %s" % operators)
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  graph_parents = []
  for operator in operators:
    graph_parents.extend(operator.graph_parents)

  with ops.name_scope(name or "add_operators", values=graph_parents):

    # Additions done in one of the tiers.  Try tier 0, 1,...
    ops_to_try_at_next_tier = list(operators)
    for tier in addition_tiers:
      ops_to_try_at_this_tier = ops_to_try_at_next_tier
      ops_to_try_at_next_tier = []
      while ops_to_try_at_this_tier:
        op1 = ops_to_try_at_this_tier.pop()
        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
        if op2 is not None:
          # Will try to add the result of this again at this same tier.
          new_operator = adder.add(op1, op2, operator_name)
          ops_to_try_at_this_tier.append(new_operator)
        else:
          # Nothing in this tier can absorb op1; defer it to the next tier
          # (or to the final result).
          ops_to_try_at_next_tier.append(op1)

    return ops_to_try_at_next_tier
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random_normal(shape=(2, 3, 1, 4, 4))
  y = tf.random_normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  # Validate and materialize before entering the name scope, so that invalid
  # input is rejected up front and `values=` receives a concrete list rather
  # than a possibly one-shot iterable.
  check_ops.assert_proper_iterable(batch_matrices)
  batch_matrices = list(batch_matrices)

  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    # With zero or one matrix there is nothing to broadcast against.
    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.get_shape()[:-2])
    if bcast_batch_shape.is_fully_defined():
      # The [1, 1] at the end will broadcast with anything.
      bcast_shape = bcast_batch_shape.concatenate([1, 1])
      for i, mat in enumerate(batch_matrices):
        # Only touch matrices whose batch shape differs from the target.
        if mat.get_shape()[:-2] != bcast_batch_shape:
          batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)

    return batch_matrices
    def __init__(self,
                 operators,
                 is_non_singular=None,
                 is_self_adjoint=None,
                 is_positive_definite=None,
                 is_square=True,
                 name=None):
        r"""Create a `LinearOperatorBlockDiag` from its diagonal blocks.

        The operator is built from a list of blocks `[op_1,...,op_J]`.

        Args:
          operators:  Non-empty iterable of `LinearOperator` objects that all
            have the same `dtype` and composable shape.
          is_non_singular:  Expect that this operator is non-singular.  Set to
            `True` automatically when every block is non-singular.
          is_self_adjoint:  Expect that this operator is equal to its
            hermitian transpose.  Set to `True` automatically when every block
            is self-adjoint.
          is_positive_definite:  Expect that this operator is positive
            definite, meaning the quadratic form `x^H A x` has positive real
            part for all nonzero `x`.  Self-adjointness is not required.  Set
            to `True` automatically when every block is positive definite.
            See:
            https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
          is_square:  Expect that this operator acts like square [batch]
            matrices.  This is true by default; a falsy value raises a
            `ValueError`.
          name: A name for this `LinearOperator`.  Defaults to the block names
            joined with `_ds_`.

        Raises:
          TypeError:  If the blocks do not all have the same `dtype`.
          ValueError:  If `operators` is empty, if `is_square` is falsy or any
            block does not report itself as square, or if a hint is explicitly
            `False` although every block has the corresponding property.
        """
        # Materialize the blocks and reject an empty collection.
        check_ops.assert_proper_iterable(operators)
        blocks = list(operators)
        if len(blocks) == 0:
            raise ValueError(
                "Expected a non-empty list of operators. Found: %s" % blocks)
        self._operators = blocks

        # For this operator type every block is a diagonal block; the
        # attribute exists for functions shared across blockwise
        # `LinearOperator` types.
        self._diagonal_operators = blocks

        # Every block must agree with the dtype of the first one.
        dtype = blocks[0].dtype
        if any(block.dtype != dtype for block in blocks):
            described = [str((block.name, block.dtype)) for block in blocks]
            raise TypeError(
                "Expected all operators to have the same dtype.  Found %s"
                % "   ".join(described))

        # Derive each hint from the blocks and reject contradicting user
        # hints.
        blocks_non_singular = all(block.is_non_singular for block in blocks)
        if blocks_non_singular and is_non_singular is False:
            raise ValueError(
                "The direct sum of non-singular operators is always non-singular."
            )
        if blocks_non_singular:
            is_non_singular = True

        blocks_self_adjoint = all(block.is_self_adjoint for block in blocks)
        if blocks_self_adjoint and is_self_adjoint is False:
            raise ValueError(
                "The direct sum of self-adjoint operators is always self-adjoint."
            )
        if blocks_self_adjoint:
            is_self_adjoint = True

        blocks_positive_definite = all(
            block.is_positive_definite for block in blocks)
        if blocks_positive_definite and is_positive_definite is False:
            raise ValueError(
                "The direct sum of positive definite operators is always "
                "positive definite.")
        if blocks_positive_definite:
            is_positive_definite = True

        # Both the hint and every block must be square.
        if not is_square or not all(block.is_square for block in blocks):
            raise ValueError(
                "Can only represent a block diagonal of square matrices.")

        # Collect the graph parents of every block, in block order.
        graph_parents = [
            parent for block in blocks for parent in block.graph_parents
        ]

        if name is None:
            # "ds" stands for direct sum.
            name = "_ds_".join(block.name for block in blocks)

        with ops.name_scope(name, values=graph_parents):
            super(LinearOperatorBlockDiag, self).__init__(
                dtype=dtype,
                graph_parents=None,
                is_non_singular=is_non_singular,
                is_self_adjoint=is_self_adjoint,
                is_positive_definite=is_positive_definite,
                is_square=True,
                name=name)

        # TODO(b/143910018) Remove graph_parents in V3.
        self._set_graph_parents(graph_parents)
  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               name=None):
    r"""Create a `LinearOperatorComposition` from a chain of operators.

    The operator is built from a list `[op_1,...,op_J]`.  For the `apply`
    method to be well defined, the composition `op_i.apply(op_{i+1}(x))` must
    be defined.  Other methods have similar constraints.

    Args:
      operators:  Non-empty iterable of `LinearOperator` objects that all have
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.  Set to
        `True` automatically when every operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Self-adjointness is not required.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      name: A name for this `LinearOperator`.  Defaults to the operator names
        joined with `_o_`.

    Raises:
      TypeError:  If the operators do not all have the same `dtype`.
      ValueError:  If `operators` is empty, or if `is_non_singular` is `False`
        although every operator is non-singular.
    """
    # Materialize the operators and reject an empty collection.
    check_ops.assert_proper_iterable(operators)
    links = list(operators)
    if len(links) == 0:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % links)
    self._operators = links

    # Every operator must agree with the dtype of the first one.
    dtype = links[0].dtype
    if any(link.dtype != dtype for link in links):
      described = [str((link.name, link.dtype)) for link in links]
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(described))

    # Derive the non-singular hint from the operators and reject a
    # contradicting user hint.
    links_non_singular = all(link.is_non_singular for link in links)
    if links_non_singular and is_non_singular is False:
      raise ValueError(
          "The composition of non-singular operators is always non-singular.")
    if links_non_singular:
      is_non_singular = True

    # Collect the graph parents of every operator, in operator order.
    graph_parents = [
        parent for link in links for parent in link.graph_parents
    ]

    if name is None:
      name = "_o_".join(link.name for link in links)

    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          graph_parents=graph_parents,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          name=name)
Example #14
0
 def test_generator_does_not_raise(self):
   """A real generator of tensors is accepted by `assert_proper_iterable`."""
   # A parenthesized, comma-separated pair is a tuple, not a generator.  Use a
   # generator expression so the test exercises what its name promises.
   generator_of_stuff = (
       constant_op.constant(values) for values in ([11, 22], [1, 2]))
   check_ops.assert_proper_iterable(generator_of_stuff)
    def __init__(self,
                 operators,
                 is_non_singular=None,
                 is_self_adjoint=None,
                 is_positive_definite=None,
                 is_square=None,
                 name=None):
        r"""Create a `LinearOperatorKronecker` from its Kronecker factors.

        The operator is built from a list of factors `[op_1,...,op_J]`.

        Args:
          operators:  Non-empty iterable of `LinearOperator` objects that all
            have the same `dtype` and composable shape; these are the
            Kronecker factors.
          is_non_singular:  Expect that this operator is non-singular.  Set to
            `True` automatically when every factor is non-singular.
          is_self_adjoint:  Expect that this operator is equal to its
            hermitian transpose.  Set to `True` automatically when every
            factor is self-adjoint.
          is_positive_definite:  Expect that this operator is positive
            definite, meaning the quadratic form `x^H A x` has positive real
            part for all nonzero `x`.  Self-adjointness is not required.  Set
            to `True` automatically when every factor is positive definite.
            See:
            https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
          is_square:  Expect that this operator acts like square [batch]
            matrices.
          name: A name for this `LinearOperator`.  Defaults to the factor
            names joined with `_x_`.

        Raises:
          TypeError:  If the factors do not all have the same `dtype`.
          ValueError:  If `operators` is empty, or if a hint is explicitly
            `False` although every factor has the corresponding property.
        """
        # Record the constructor arguments exactly as the caller passed them,
        # before any defaults or derived hints are applied.
        parameters = {
            "operators": operators,
            "is_non_singular": is_non_singular,
            "is_self_adjoint": is_self_adjoint,
            "is_positive_definite": is_positive_definite,
            "is_square": is_square,
            "name": name,
        }

        # Materialize the factors and reject an empty collection.
        check_ops.assert_proper_iterable(operators)
        factors = list(operators)
        if len(factors) == 0:
            raise ValueError(
                f"Argument `operators` must be a list of >=1 operators. "
                f"Received: {factors}.")
        self._operators = factors

        # Every factor must agree with the dtype of the first one.
        dtype = factors[0].dtype
        if any(factor.dtype != dtype for factor in factors):
            described = [
                str((factor.name, factor.dtype)) for factor in factors
            ]
            raise TypeError(
                f"Expected every operation in argument `operators` to have the "
                f"same dtype. Received {described}.")

        # A Kronecker product is invertible if and only if all factors are
        # invertible.
        factors_non_singular = all(
            factor.is_non_singular for factor in factors)
        if factors_non_singular and is_non_singular is False:
            raise ValueError(
                f"The Kronecker product of non-singular operators is always "
                f"non-singular. Expected argument `is_non_singular` to be True. "
                f"Received: {is_non_singular}.")
        if factors_non_singular:
            is_non_singular = True

        factors_self_adjoint = all(
            factor.is_self_adjoint for factor in factors)
        if factors_self_adjoint and is_self_adjoint is False:
            raise ValueError(
                f"The Kronecker product of self-adjoint operators is always "
                f"self-adjoint. Expected argument `is_self_adjoint` to be True. "
                f"Received: {is_self_adjoint}.")
        if factors_self_adjoint:
            is_self_adjoint = True

        # The eigenvalues of a Kronecker product are the products of the
        # eigenvalues of the factors.
        factors_positive_definite = all(
            factor.is_positive_definite for factor in factors)
        if factors_positive_definite and is_positive_definite is False:
            raise ValueError(
                f"The Kronecker product of positive-definite operators is always "
                f"positive-definite. Expected argument `is_positive_definite` to "
                f"be True. Received: {is_positive_definite}.")
        if factors_positive_definite:
            is_positive_definite = True

        # Collect the graph parents of every factor, in factor order.
        graph_parents = [
            parent for factor in factors for parent in factor.graph_parents
        ]

        if name is None:
            name = "_x_".join(factor.name for factor in factors)

        with ops.name_scope(name, values=graph_parents):
            super(LinearOperatorKronecker, self).__init__(
                dtype=dtype,
                is_non_singular=is_non_singular,
                is_self_adjoint=is_self_adjoint,
                is_positive_definite=is_positive_definite,
                is_square=is_square,
                parameters=parameters,
                name=name)
        # TODO(b/143910018) Remove graph_parents in V3.
        self._set_graph_parents(graph_parents)
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast the leading (batch) dimensions of a collection of matrices.

  The trailing two dimensions of every input are treated as the matrix
  dimensions and are left alone; only the dimensions in front of them are
  broadcast against each other.

  Broadcasting a single batch dim of two small matrices:

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Broadcasting several batch dims at once:

  ```python
  x = tf.random_normal(shape=(2, 3, 1, 4, 4))
  y = tf.random_normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.  Defaults to
      "broadcast_matrix_batch_dims".

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.
      When fewer than two matrices are given, the converted inputs are
      returned without any broadcasting.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have fewer
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    matrices = list(batch_matrices)

    # Convert and validate one input at a time, in order.
    for idx, raw in enumerate(matrices):
      tensor = ops.convert_to_tensor(raw)
      matrices[idx] = tensor
      assert_is_batch_matrix(tensor)

    # Nothing to broadcast against with zero or one matrix.
    if len(matrices) < 2:
      return matrices

    # First attempt: resolve the common batch shape from static shapes.
    # The common batch shape is the broadcast of ALL batch shapes, e.g. for
    #   x.shape =    [2, j, k]  (batch shape =    [2])
    #   y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # the common batch shape is [3, 2].
    static_batch_shape = matrices[0].get_shape()[:-2]
    for tensor in matrices[1:]:
      static_batch_shape = array_ops.broadcast_static_shape(
          static_batch_shape, tensor.get_shape()[:-2])

    if static_batch_shape.is_fully_defined():
      # A trailing [1, 1] broadcasts against any matrix dimensions.
      static_target = static_batch_shape.concatenate([1, 1])
      for idx, tensor in enumerate(matrices):
        # Only touch the matrices whose batch shape actually differs.
        if tensor.get_shape()[:-2] != static_batch_shape:
          matrices[idx] = _broadcast_to_shape(tensor, static_target)
      return matrices

    # Static shapes were insufficient; fall back to dynamic shapes, which
    # always copies data.
    dynamic_batch_shape = array_ops.shape(matrices[0])[:-2]
    for tensor in matrices[1:]:
      dynamic_batch_shape = array_ops.broadcast_dynamic_shape(
          dynamic_batch_shape, array_ops.shape(tensor)[:-2])
    dynamic_target = array_ops.concat([dynamic_batch_shape, [1, 1]], axis=0)
    for idx, tensor in enumerate(matrices):
      matrices[idx] = _broadcast_to_shape(tensor, dynamic_target)

    return matrices
Example #17
0
 def test_single_tensor_raises(self):
   """Check `assert_proper_iterable` rejects a single constant.

   The call is expected to raise a `TypeError` whose message matches the
   pattern "proper".
   """
   tensor = constant_op.constant(1)
   # `assertRaisesRegexp` is a deprecated alias that was removed from
   # `unittest.TestCase` in Python 3.12; `assertRaisesRegex` is the
   # supported spelling.
   with self.assertRaisesRegex(TypeError, "proper"):
     check_ops.assert_proper_iterable(tensor)
  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Build a `LinearOperatorComposition` from a sequence of operators.

    The operator is constructed from a list `[op_1,...,op_J]`.  For `matmul`
    to be well defined, each composition `op_i.matmul(op_{i+1}(x))` must be
    defined; other methods carry similar constraints.

    If every supplied operator reports `is_non_singular`, the composition is
    marked non-singular automatically.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if `is_non_singular` is `False`
        while every operator is non-singular.
    """
    # Snapshot the constructor arguments exactly as they were passed in.
    parameters = {
        "operators": operators,
        "is_non_singular": is_non_singular,
        "is_self_adjoint": is_self_adjoint,
        "is_positive_definite": is_positive_definite,
        "is_square": is_square,
        "name": name,
    }

    # The operators must form a real, non-empty collection.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Every operator must share the dtype of the first one.
    dtype = operators[0].dtype
    if any(op.dtype != dtype for op in operators):
      described = (str((op.name, op.dtype)) for op in operators)
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(described))

    # Infer the non-singular hint and reject a contradictory caller hint.
    factors_non_singular = all(op.is_non_singular for op in operators)
    if factors_non_singular and is_non_singular is False:
      raise ValueError(
          "The composition of non-singular operators is always non-singular.")
    if factors_non_singular:
      is_non_singular = True

    # Default name: the operator names joined with "_o_".
    if name is None:
      name = "_o_".join(op.name for op in operators)

    with ops.name_scope(name):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name=None):
    r"""Build a `LinearOperatorBlockDiag` from a sequence of operators.

    The operator is constructed from a list `[op_1,...,op_J]`.

    The non-singular, self-adjoint and positive-definite hints are each set
    to `True` automatically when every supplied operator reports that
    property.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and each reporting `is_square`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is true by default, and will raise a `ValueError` otherwise.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_ds_` (direct sum).

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty or are non-square, or if a hint is
        passed as `False` while every operator has the corresponding property.
    """
    # The operators must form a real, non-empty collection.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Every operator must share the dtype of the first one.
    dtype = operators[0].dtype
    if any(op.dtype != dtype for op in operators):
      described = (str((op.name, op.dtype)) for op in operators)
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(described))

    # Infer each hint from the blocks and reject a contradictory caller hint.
    blocks_non_singular = all(op.is_non_singular for op in operators)
    if blocks_non_singular and is_non_singular is False:
      raise ValueError(
          "The direct sum of non-singular operators is always non-singular.")
    if blocks_non_singular:
      is_non_singular = True

    blocks_self_adjoint = all(op.is_self_adjoint for op in operators)
    if blocks_self_adjoint and is_self_adjoint is False:
      raise ValueError(
          "The direct sum of self-adjoint operators is always self-adjoint.")
    if blocks_self_adjoint:
      is_self_adjoint = True

    blocks_positive_definite = all(
        op.is_positive_definite for op in operators)
    if blocks_positive_definite and is_positive_definite is False:
      raise ValueError(
          "The direct sum of positive definite operators is always "
          "positive definite.")
    if blocks_positive_definite:
      is_positive_definite = True

    # Both the caller's hint and every block must be (truthy) square.
    if not is_square or not all(op.is_square for op in operators):
      raise ValueError(
          "Can only represent a block diagonal of square matrices.")

    # Gather the graph parents of all blocks, preserving operator order.
    graph_parents = [
        parent for op in operators for parent in op.graph_parents]

    # Default name: the operator names joined with "_ds_" (direct sum).
    if name is None:
      name = "_ds_".join(op.name for op in operators)

    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorBlockDiag, self).__init__(
          dtype=dtype,
          graph_parents=graph_parents,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=True,
          name=name)