Example #1
def _slice_single_param(param, param_ndims_to_matrix_ndims, slices,
                        batch_shape):
    """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
    # Broadcast the parameter to have full batch rank.
    param = _broadcast_parameter_with_batch_shape(
        param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

    if hasattr(param, 'batch_shape_tensor'):
        param_batch_shape = param.batch_shape_tensor()
    else:
        param_batch_shape = prefer_static.shape(param)
    # Truncate by param_ndims_to_matrix_ndims
    param_batch_rank = array_ops.size(param_batch_shape)
    param_batch_shape = param_batch_shape[:(param_batch_rank -
                                            param_ndims_to_matrix_ndims)]

    # At this point the param should have full batch rank, *unless* it's an
    # atomic object like `tfb.Identity()` incapable of having any batch rank.
    if (ops.get_static_value(array_ops.size(batch_shape)) != 0
            and ops.get_static_value(array_ops.size(param_batch_shape)) == 0):
        return param
    param_slices = _sanitize_slices(slices,
                                    intended_shape=batch_shape,
                                    deficient_shape=param_batch_shape)

    # Extend `param_slices` (which represents slicing into the parameter's
    # batch shape) with full slices over the parameter's matrix dimensions.
    # For example, if `param_ndims_to_matrix_ndims == 1`, then `[i, ..., j]`
    # becomes `[i, ..., j, :]`.
    if param_ndims_to_matrix_ndims > 0:
        if Ellipsis not in [slc for slc in slices if not ops.is_tensor(slc)]:
            param_slices.append(Ellipsis)
        param_slices = (param_slices +
                        [slice(None)] * param_ndims_to_matrix_ndims)
    return param.__getitem__(tuple(param_slices))
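
The slice-extension step at the end can be exercised in isolation. Below is a minimal NumPy sketch; the helper name `extend_batch_slices` is invented for illustration, and it omits the `ops.is_tensor` filtering the real code applies before testing for `Ellipsis`:

import numpy as np

# Pad batch slices with full slices over the parameter's rightmost
# (matrix) dimensions, so `param[i, ..., j]` becomes `param[i, ..., j, :, :]`
# when `param_ndims_to_matrix_ndims == 2`.
def extend_batch_slices(slices, param_ndims_to_matrix_ndims):
    slices = list(slices)
    if Ellipsis not in slices:
        slices.append(Ellipsis)
    return tuple(slices + [slice(None)] * param_ndims_to_matrix_ndims)

param = np.zeros([4, 5, 3, 3])      # batch shape [4, 5], matrix shape [3, 3]
sliced = param[extend_batch_slices((0,), 2)]
print(sliced.shape)                 # (5, 3, 3): only the batch dims sliced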
Example #2
def _prefer_static_where(condition, x, y):
    args = [condition, x, y]
    constant_args = [ops.get_static_value(a) for a in args]
    # Do this statically.
    if all(arg is not None for arg in constant_args):
        condition_, x_, y_ = constant_args
        return np.where(condition_, x_, y_)
    return array_ops.where(condition, x, y)
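
For reference, the same static-dispatch pattern can be reproduced with public APIs (`tf.get_static_value` / `tf.where`); this is a sketch assuming TF 2.x semantics, not the module-internal version above:

import numpy as np
import tensorflow as tf

def prefer_static_where(condition, x, y):
    # Fold the select at trace time when every input is statically known.
    static = [tf.get_static_value(tf.convert_to_tensor(a))
              for a in (condition, x, y)]
    if all(s is not None for s in static):
        return np.where(*static)
    return tf.where(condition, x, y)

print(prefer_static_where(True, 1., 0.))  # NumPy result, folded statically

@tf.function
def f(c):
    # A traced `c` has no static value, so this lowers to tf.where.
    return prefer_static_where(c, 1., 0.)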
Example #3
def override_body_fn(args, _):
    c = cond(*args)
    sc = ops.get_static_value(c)
    if sc is None:
        args = lax.cond(c, args, lambda args: body(*args), args,
                        lambda args: args)
    elif sc:
        args = body(*args)
    return args, ()
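
For comparison, here is a sketch of the same short-circuit written against the current `lax.cond` signature (`pred, true_fun, false_fun, *operands`); the `jax.core.Tracer` test is an assumption standing in for `ops.get_static_value`:

import jax
import jax.numpy as jnp
from jax import lax

def maybe_step(cond_fn, body_fn, args):
    c = cond_fn(*args)
    if not isinstance(c, jax.core.Tracer):
        # Condition is concrete (not under tracing): branch in Python,
        # mirroring the `sc is not None` static path above.
        return body_fn(*args) if c else args
    # Condition only known at run time: keep both branches in the graph.
    return lax.cond(c, lambda a: body_fn(*a), lambda a: a, args)

out = maybe_step(lambda x, i: i < 3,
                 lambda x, i: (x + 1., i + 1),
                 (jnp.zeros(3), 0))
print(out)  # ([1., 1., 1.], 1): static branch taken without lax.cond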
Example #4
def _prefer_static_concat_shape(first_shape, second_shape_int_list):
  """Concatenate a shape with a list of integers as statically as possible.

  Args:
    first_shape: `TensorShape` or `Tensor` instance. If a `TensorShape`,
      `first_shape.is_fully_defined()` must return `True`.
    second_shape_int_list: `list` of scalar integer `Tensor`s.

  Returns:
    `Tensor` representing concatenating `first_shape` and
      `second_shape_int_list` as statically as possible.
  """
  second_shape_int_list_static = [
      ops.get_static_value(s) for s in second_shape_int_list]
  if (isinstance(first_shape, tensor_shape.TensorShape) and
      all(s is not None for s in second_shape_int_list_static)):
    return first_shape.concatenate(second_shape_int_list_static)
  return prefer_static.concat([first_shape, second_shape_int_list], axis=0)
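
A usage sketch of the same idea with public APIs (`tf.TensorShape`, `tf.get_static_value`), assuming TF 2.x:

import tensorflow as tf

first_shape = tf.TensorShape([2, 3])           # fully defined
second = [tf.constant(4), tf.constant(4)]      # scalar int Tensors

static = [tf.get_static_value(s) for s in second]
if first_shape.is_fully_defined() and all(s is not None for s in static):
    # All pieces known at graph-build time: stay in TensorShape land.
    result = first_shape.concatenate([int(s) for s in static])
else:
    # Runtime fallback (stands in for prefer_static.concat above).
    result = tf.concat([tf.constant(first_shape.as_list(), tf.int32),
                        tf.stack(second)], axis=0)
print(result)  # (2, 3, 4, 4)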
Example #5
    def __init__(self,
                 num_rows,
                 num_columns=None,
                 batch_shape=None,
                 dtype=None,
                 is_non_singular=False,
                 is_self_adjoint=True,
                 is_positive_definite=False,
                 is_square=True,
                 assert_proper_shapes=False,
                 name="LinearOperatorZeros"):
        r"""Initialize a `LinearOperatorZeros`.

    The `LinearOperatorZeros` is initialized with arguments defining `dtype`
    and shape.

    This operator is able to broadcast the leading (batch) dimensions, which
    sometimes requires copying data.  If `batch_shape` is `None`, the operator
    can take arguments of any batch shape without copying.  See examples.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding zero matrix.
      num_columns:  Scalar non-negative integer `Tensor`.  Number of columns in
        the corresponding zero matrix. If `None`, defaults to the value of
        `num_rows`.
      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
        dimensions.  If `None`, this operator has no leading dimensions.
      dtype:  Data type of the matrix that this operator represents.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `num_columns` is determined statically to be non-scalar,
        or negative.
      ValueError:  If `batch_shape` is determined statically not to be 1-D,
        or to contain negative entries.
      ValueError:  If `is_non_singular` or `is_positive_definite` is `True`,
        or if `is_self_adjoint` is `False` while `is_square` is `True`.
    """
        parameters = dict(num_rows=num_rows,
                          num_columns=num_columns,
                          batch_shape=batch_shape,
                          dtype=dtype,
                          is_non_singular=is_non_singular,
                          is_self_adjoint=is_self_adjoint,
                          is_positive_definite=is_positive_definite,
                          is_square=is_square,
                          assert_proper_shapes=assert_proper_shapes,
                          name=name)

        dtype = dtype or dtypes.float32
        self._assert_proper_shapes = assert_proper_shapes

        with ops.name_scope(name):
            dtype = dtypes.as_dtype(dtype)
            if not is_self_adjoint and is_square:
                raise ValueError("A zero operator is always self adjoint.")
            if is_non_singular:
                raise ValueError("A zero operator is always singular.")
            if is_positive_definite:
                raise ValueError(
                    "A zero operator is never positive-definite.")

            super(LinearOperatorZeros,
                  self).__init__(dtype=dtype,
                                 is_non_singular=is_non_singular,
                                 is_self_adjoint=is_self_adjoint,
                                 is_positive_definite=is_positive_definite,
                                 is_square=is_square,
                                 parameters=parameters,
                                 name=name)

            linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
            linear_operator_util.assert_not_ref_type(num_columns,
                                                     "num_columns")
            linear_operator_util.assert_not_ref_type(batch_shape,
                                                     "batch_shape")

            self._num_rows = linear_operator_util.shape_tensor(num_rows,
                                                               name="num_rows")
            self._num_rows_static = ops.get_static_value(self._num_rows)

            if num_columns is None:
                num_columns = num_rows

            self._num_columns = linear_operator_util.shape_tensor(
                num_columns, name="num_columns")
            self._num_columns_static = ops.get_static_value(self._num_columns)

            self._check_domain_range_possibly_add_asserts()

            if (self._num_rows_static is not None
                    and self._num_columns_static is not None):
                if is_square and self._num_rows_static != self._num_columns_static:
                    raise ValueError(
                        "LinearOperatorZeros initialized as is_square=True, but got "
                        "num_rows({}) != num_columns({})".format(
                            self._num_rows_static, self._num_columns_static))

            if batch_shape is None:
                self._batch_shape_arg = None
            else:
                self._batch_shape_arg = linear_operator_util.shape_tensor(
                    batch_shape, name="batch_shape_arg")
                self._batch_shape_static = ops.get_static_value(
                    self._batch_shape_arg)
                self._check_batch_shape_possibly_add_asserts()
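
Usage sketch through the public export `tf.linalg.LinearOperatorZeros`, mirroring the hint checks above (a zero operator must be singular, not positive-definite, and, when square, self-adjoint):

import tensorflow as tf

op = tf.linalg.LinearOperatorZeros(
    num_rows=2, num_columns=3, is_square=False, is_self_adjoint=False)
print(op.shape)                    # (2, 3)
print(op.matmul(tf.ones([3, 4])))  # 2x4 all-zero matrix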
Example #6
    def __init__(self,
                 num_rows,
                 multiplier,
                 is_non_singular=None,
                 is_self_adjoint=None,
                 is_positive_definite=None,
                 is_square=True,
                 assert_proper_shapes=False,
                 name="LinearOperatorScaledIdentity"):
        r"""Initialize a `LinearOperatorScaledIdentity`.

    The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
    determines the size of each identity matrix, and a `multiplier`,
    which defines `dtype`, batch shape, and scale of each matrix.

    This operator is able to broadcast the leading (batch) dimensions.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      multiplier:  `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
    """
        parameters = dict(num_rows=num_rows,
                          multiplier=multiplier,
                          is_non_singular=is_non_singular,
                          is_self_adjoint=is_self_adjoint,
                          is_positive_definite=is_positive_definite,
                          is_square=is_square,
                          assert_proper_shapes=assert_proper_shapes,
                          name=name)

        self._assert_proper_shapes = assert_proper_shapes

        with ops.name_scope(name, values=[multiplier, num_rows]):
            self._multiplier = linear_operator_util.convert_nonref_to_tensor(
                multiplier, name="multiplier")

            # Check and auto-set hints.
            if not np.issubdtype(self._multiplier.dtype, np.complexfloating):
                if is_self_adjoint is False:  # pylint: disable=g-bool-id-comparison
                    raise ValueError(
                        "A real diagonal operator is always self adjoint.")
                else:
                    is_self_adjoint = True

            if not is_square:
                raise ValueError("A ScaledIdentity operator is always square.")

            linear_operator_util.assert_not_ref_type(num_rows, "num_rows")

            super(LinearOperatorScaledIdentity,
                  self).__init__(dtype=self._multiplier.dtype,
                                 is_non_singular=is_non_singular,
                                 is_self_adjoint=is_self_adjoint,
                                 is_positive_definite=is_positive_definite,
                                 is_square=is_square,
                                 parameters=parameters,
                                 name=name)

            self._num_rows = linear_operator_util.shape_tensor(num_rows,
                                                               name="num_rows")
            self._num_rows_static = ops.get_static_value(self._num_rows)
            self._check_num_rows_possibly_add_asserts()
            self._num_rows_cast_to_dtype = _ops.cast(self._num_rows,
                                                     self.dtype)
            self._num_rows_cast_to_real_dtype = _ops.cast(
                self._num_rows, dtypes.real_dtype(self.dtype))
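
Usage sketch through the public export `tf.linalg.LinearOperatorScaledIdentity`:

import tensorflow as tf

# A 2x2 identity scaled by 3.; `multiplier` fixes the dtype and batch shape.
op = tf.linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)
print(op.to_dense())             # [[3., 0.], [0., 3.]]
print(op.is_self_adjoint)        # True: auto-set for real multipliers above
print(op.matvec(tf.ones([2])))   # [3., 3.]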
Example #7
    def _solve_matmul_internal(self,
                               x,
                               solve_matmul_fn,
                               adjoint=False,
                               adjoint_arg=False):
        # We rely heavily on Roth's column lemma [1]:
        # (A x B) vec(X) = vec(B X A^T),
        # where vec stacks all the columns of a matrix under each other.
        # Here we use a row-major-friendly variant of the lemma:
        # (A x B) vec'(X) = vec'(A X B^T),
        # where vec' flattens a matrix row by row. This can be applied
        # repeatedly for a chain of Kronecker products.
        # Since (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, the same
        # identity lets us compute both matmuls and solves under any
        # composition of transposes.
        output = x

        if adjoint_arg:
            if np.issubdtype(self.dtype, np.complexfloating):
                output = math_ops.conj(output)
        else:
            output = linalg.transpose(output)

        for o in reversed(self.operators):
            # Statically compute the reshape.
            if adjoint:
                operator_dimension = o.range_dimension_tensor()
            else:
                operator_dimension = o.domain_dimension_tensor()
            output_shape = _prefer_static_shape(output)

            operator_dimension_static = ops.get_static_value(
                operator_dimension)
            output_static_shape = tensor_shape.TensorShape(output.shape)
            if (operator_dimension_static is not None
                    and output_static_shape[-2] is not None
                    and output_static_shape[-1] is not None):
                operator_dimension = operator_dimension_static
                dim = int(output_static_shape[-2] * output_static_shape[-1] //
                          operator_dimension_static)
            else:
                # Fall back to the dynamic computation when any static value
                # is unknown; this also guarantees `dim` is always defined.
                dim = _ops.cast(output_shape[-2] * output_shape[-1] //
                                operator_dimension,
                                dtype=dtypes.int32)

            output_shape = _prefer_static_concat_shape(
                output_shape[:-2], [dim, operator_dimension])
            output = array_ops.reshape(output, shape=output_shape)

            # Conjugate because we are trying to compute A @ B^T, but
            # `LinearOperator` only supports `adjoint_arg`.
            if np.issubdtype(self.dtype, np.complexfloating):
                output = math_ops.conj(output)

            output = solve_matmul_fn(o,
                                     output,
                                     adjoint=adjoint,
                                     adjoint_arg=True)

        if adjoint_arg:
            col_dim = _prefer_static_shape(x)[-2]
        else:
            col_dim = _prefer_static_shape(x)[-1]

        if adjoint:
            row_dim = self.domain_dimension_tensor()
        else:
            row_dim = self.range_dimension_tensor()

        matrix_shape = [row_dim, col_dim]

        output = array_ops.reshape(
            output,
            _prefer_static_concat_shape(
                _prefer_static_shape(output)[:-2], matrix_shape))

        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            if adjoint_arg:
                column_dim = tensor_shape.TensorShape(x.shape)[-2]
            else:
                column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
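
The row-major variant of Roth's column lemma that the comments rely on can be sanity-checked directly in NumPy, where row-major `reshape(-1)` plays the role of `vec'`:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((4, 5))
X = rng.standard_normal((3, 5))

# (A x B) vec'(X) == vec'(A X B^T), with vec' the row-major flattening.
lhs = np.kron(A, B) @ X.reshape(-1)
rhs = (A @ X @ B.T).reshape(-1)
assert np.allclose(lhs, rhs)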