def _slice_single_param(param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Applies batch slices to one parameter of a batched object.

  Args:
    param: The parameter to slice; either a `Tensor` or an object exposing
      `batch_shape_tensor()` (e.g. a `LinearOperator`).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used
      for inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: The result of indexing the broadcast `param` with the sanitized
      slices, or the broadcast `param` itself when it has no batch dimensions
      to slice while the parent does.
  """
  # Broadcast the parameter so that it has full batch rank.
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  untruncated_shape = (
      param.batch_shape_tensor() if hasattr(param, 'batch_shape_tensor')
      else prefer_static.shape(param))

  # Drop the trailing `param_ndims_to_matrix_ndims` entries; what is left is
  # the part of the shape that batch slicing acts on.
  untruncated_rank = array_ops.size(untruncated_shape)
  param_batch_shape = untruncated_shape[
      :(untruncated_rank - param_ndims_to_matrix_ndims)]

  # By now the param should have full batch rank, *unless* it is an atomic
  # object like `tfb.Identity()` that cannot carry any batch rank; such an
  # object is handed back unsliced.
  parent_is_batched = (
      ops.get_static_value(array_ops.size(batch_shape)) != 0)
  if (parent_is_batched and
      ops.get_static_value(array_ops.size(param_batch_shape)) == 0):
    return param

  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # `param_slices` only addresses the batch shape. Pad it with full slices for
  # the trailing non-batch dims, e.g. with one such dim `[i, ..., j]` becomes
  # `[i, ..., j, :]`. An `Ellipsis` is inserted first if the caller gave none.
  if param_ndims_to_matrix_ndims > 0:
    non_tensor_slices = [slc for slc in slices if not ops.is_tensor(slc)]
    if Ellipsis not in non_tensor_slices:
      param_slices.append(Ellipsis)
    trailing_full_slices = [slice(None)] * param_ndims_to_matrix_ndims
    param_slices = param_slices + trailing_full_slices
  return param.__getitem__(tuple(param_slices))
def _prefer_static_where(condition, x, y):
  """Evaluates `where(condition, x, y)` statically whenever possible.

  Args:
    condition: Selection mask passed as the first argument of `where`.
    x: Values passed as the second argument of `where`.
    y: Values passed as the third argument of `where`.

  Returns:
    `np.where` applied to the static values when `ops.get_static_value`
    yields a value for all three arguments; otherwise
    `array_ops.where(condition, x, y)`.
  """
  static_values = [ops.get_static_value(arg) for arg in (condition, x, y)]
  if any(value is None for value in static_values):
    # At least one argument is only known dynamically.
    return array_ops.where(condition, x, y)
  static_condition, static_x, static_y = static_values
  return np.where(static_condition, static_x, static_y)
def override_body_fn(args, _):
  """Runs `body` on `args` only if `cond(*args)` holds.

  `cond`, `body` and `lax` are not defined in this block; they are resolved
  from the surrounding scope.

  Args:
    args: Sequence of loop variables, unpacked into `cond` and `body`.
    _: Ignored.

  Returns:
    A pair `(new_args, ())` where `new_args` is `body(*args)` when the
    condition holds and `args` unchanged otherwise.
  """
  keep_going = cond(*args)
  static_keep_going = ops.get_static_value(keep_going)
  if static_keep_going is None:
    # The predicate has no static value, so the branch is chosen by
    # `lax.cond` instead of by Python control flow.
    selected = lax.cond(
        keep_going,
        args, lambda operand: body(*operand),
        args, lambda operand: operand)
    return selected, ()
  if static_keep_going:
    return body(*args), ()
  return args, ()
def _prefer_static_concat_shape(first_shape, second_shape_int_list):
  """Appends a list of integers to a shape, staying static when possible.

  Args:
    first_shape: `TensorShape` or `Tensor` instance. If a `TensorShape`,
      `first_shape.is_fully_defined()` must return `True`.
    second_shape_int_list: `list` of scalar integer `Tensor`s.

  Returns:
    The concatenation of `first_shape` and `second_shape_int_list`. When
    `first_shape` is a `TensorShape` and every entry of
    `second_shape_int_list` has a static value, this is the result of
    `first_shape.concatenate(...)`; otherwise it is the result of
    `prefer_static.concat`.
  """
  static_tail = [ops.get_static_value(dim) for dim in second_shape_int_list]
  if (not isinstance(first_shape, tensor_shape.TensorShape) or
      any(dim is None for dim in static_tail)):
    # Something is only known dynamically; let `prefer_static` handle it.
    return prefer_static.concat(
        [first_shape, second_shape_int_list], axis=0)
  return first_shape.concatenate(static_tail)
def __init__(self,
             num_rows,
             num_columns=None,
             batch_shape=None,
             dtype=None,
             is_non_singular=False,
             is_self_adjoint=True,
             is_positive_definite=False,
             is_square=True,
             assert_proper_shapes=False,
             name="LinearOperatorZeros"):
  r"""Initialize a `LinearOperatorZeros`.

  The `LinearOperatorZeros` is initialized with arguments defining `dtype`
  and shape.

  This operator is able to broadcast the leading (batch) dimensions, which
  sometimes requires copying data. If `batch_shape` is `None`, the operator
  can take arguments of any batch shape without copying. See examples.

  Args:
    num_rows: Scalar non-negative integer `Tensor`. Number of rows in the
      corresponding zero matrix.
    num_columns: Scalar non-negative integer `Tensor`. Number of columns in
      the corresponding zero matrix. If `None`, defaults to the value of
      `num_rows`.
    batch_shape: Optional `1-D` integer `Tensor`. The shape of the leading
      dimensions. If `None`, this operator has no leading dimensions.
    dtype: Data type of the matrix that this operator represents. Defaults
      to `float32` when falsy.
    is_non_singular: Expect that this operator is non-singular.
    is_self_adjoint: Expect that this operator is equal to its hermitian
      transpose.
    is_positive_definite: Expect that this operator is positive definite,
      meaning the quadratic form `x^H A x` has positive real part for all
      nonzero `x`. Note that we do not require the operator to be
      self-adjoint to be positive-definite. See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
    is_square: Expect that this operator acts like square [batch] matrices.
    assert_proper_shapes: Python `bool`. If `False`, only perform static
      checks that initialization and method arguments have proper shape.
      If `True`, and static checks are inconclusive, add asserts to the
      graph.
    name: A name for this `LinearOperator`

  Raises:
    ValueError: If `num_rows` is determined statically to be non-scalar, or
      negative.
    ValueError: If `num_columns` is determined statically to be non-scalar,
      or negative.
    ValueError: If `batch_shape` is determined statically to not be 1-D, or
      negative.
    ValueError: If `is_self_adjoint` is falsy while `is_square` is truthy,
      or if `is_non_singular` or `is_positive_definite` is truthy. A zero
      operator cannot satisfy those hints.
    ValueError: If `is_square` is truthy and `num_rows` and `num_columns`
      are both statically known but differ.
  """
  parameters = dict(
      num_rows=num_rows,
      num_columns=num_columns,
      batch_shape=batch_shape,
      dtype=dtype,
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite,
      is_square=is_square,
      assert_proper_shapes=assert_proper_shapes,
      name=name)

  dtype = dtype or dtypes.float32
  self._assert_proper_shapes = assert_proper_shapes

  with ops.name_scope(name):
    dtype = dtypes.as_dtype(dtype)

    # Reject hints that contradict the properties of an all-zero matrix.
    if not is_self_adjoint and is_square:
      raise ValueError("A zero operator is always self adjoint.")
    if is_non_singular:
      raise ValueError("A zero operator is always singular.")
    if is_positive_definite:
      raise ValueError(
          "A zero operator is always not positive-definite.")

    super(LinearOperatorZeros, self).__init__(
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters,
        name=name)

    linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
    linear_operator_util.assert_not_ref_type(num_columns, "num_columns")
    linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

    self._num_rows = linear_operator_util.shape_tensor(
        num_rows, name="num_rows")
    self._num_rows_static = ops.get_static_value(self._num_rows)

    if num_columns is None:
      num_columns = num_rows

    self._num_columns = linear_operator_util.shape_tensor(
        num_columns, name="num_columns")
    self._num_columns_static = ops.get_static_value(self._num_columns)

    self._check_domain_range_possibly_add_asserts()

    if (self._num_rows_static is not None and
        self._num_columns_static is not None):
      if is_square and self._num_rows_static != self._num_columns_static:
        raise ValueError(
            "LinearOperatorZeros initialized as is_square=True, but got "
            "num_rows({}) != num_columns({})".format(
                self._num_rows_static,
                self._num_columns_static))

    if batch_shape is None:
      self._batch_shape_arg = None
      # Always define the attribute so that reading it on an operator built
      # without a batch shape cannot raise AttributeError.
      self._batch_shape_static = None
    else:
      self._batch_shape_arg = linear_operator_util.shape_tensor(
          batch_shape, name="batch_shape_arg")
      self._batch_shape_static = ops.get_static_value(
          self._batch_shape_arg)
      self._check_batch_shape_possibly_add_asserts()
def __init__(self,
             num_rows,
             multiplier,
             is_non_singular=None,
             is_self_adjoint=None,
             is_positive_definite=None,
             is_square=True,
             assert_proper_shapes=False,
             name="LinearOperatorScaledIdentity"):
  r"""Initialize a `LinearOperatorScaledIdentity`.

  `num_rows` fixes the size of each identity matrix, while `multiplier`
  supplies the `dtype`, the batch shape, and the scale of each matrix.

  This operator is able to broadcast the leading (batch) dimensions.

  Args:
    num_rows: Scalar non-negative integer `Tensor`. Number of rows in the
      corresponding identity matrix.
    multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
    is_non_singular: Expect that this operator is non-singular.
    is_self_adjoint: Expect that this operator is equal to its hermitian
      transpose. Forced to `True` when `multiplier` has a non-complex dtype.
    is_positive_definite: Expect that this operator is positive definite,
      meaning the quadratic form `x^H A x` has positive real part for all
      nonzero `x`. Note that we do not require the operator to be
      self-adjoint to be positive-definite. See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
    is_square: Expect that this operator acts like square [batch] matrices.
    assert_proper_shapes: Python `bool`. If `False`, only perform static
      checks that initialization and method arguments have proper shape.
      If `True`, and static checks are inconclusive, add asserts to the
      graph.
    name: A name for this `LinearOperator`

  Raises:
    ValueError: If `num_rows` is determined statically to be non-scalar, or
      negative.
    ValueError: If `multiplier` has a non-complex dtype and
      `is_self_adjoint` is `False`.
    ValueError: If `is_square` is falsy.
  """
  parameters = {
      "num_rows": num_rows,
      "multiplier": multiplier,
      "is_non_singular": is_non_singular,
      "is_self_adjoint": is_self_adjoint,
      "is_positive_definite": is_positive_definite,
      "is_square": is_square,
      "assert_proper_shapes": assert_proper_shapes,
      "name": name,
  }

  self._assert_proper_shapes = assert_proper_shapes

  with ops.name_scope(name, values=[multiplier, num_rows]):
    self._multiplier = linear_operator_util.convert_nonref_to_tensor(
        multiplier, name="multiplier")

    # Validate the hints and fill in the ones implied by the dtype.
    multiplier_is_complex = np.issubdtype(
        self._multiplier.dtype, np.complexfloating)
    if not multiplier_is_complex:
      if is_self_adjoint is False:  # pylint: disable=g-bool-id-comparison
        raise ValueError(
            "A real diagonal operator is always self adjoint.")
      is_self_adjoint = True

    if not is_square:
      raise ValueError("A ScaledIdentity operator is always square.")

    linear_operator_util.assert_not_ref_type(num_rows, "num_rows")

    super(LinearOperatorScaledIdentity, self).__init__(
        dtype=self._multiplier.dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters,
        name=name)

    self._num_rows = linear_operator_util.shape_tensor(
        num_rows, name="num_rows")
    self._num_rows_static = ops.get_static_value(self._num_rows)
    self._check_num_rows_possibly_add_asserts()

    # Keep `num_rows` cast to the operator dtype and to its real
    # counterpart on the instance.
    self._num_rows_cast_to_dtype = _ops.cast(self._num_rows, self.dtype)
    self._num_rows_cast_to_real_dtype = _ops.cast(
        self._num_rows, dtypes.real_dtype(self.dtype))
def _solve_matmul_internal(self,
                           x,
                           solve_matmul_fn,
                           adjoint=False,
                           adjoint_arg=False):
  """Applies `solve_matmul_fn` factor by factor over `self.operators`.

  We heavily rely on Roth's column Lemma:
    (A x B) * vec X = vec BXA^T
  where vec stacks all the columns of the matrix under each other.
  In our case, we use a variant of the lemma that is row-major
  friendly: (A x B) * vec' X = vec' AXB^T
  where vec' reshapes a matrix into a vector. We can repeatedly apply this
  for a collection of kronecker products.
  Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
  use the above to compute multiplications, solves with any composition of
  transposes.

  Args:
    x: The right-hand argument; its last two dimensions are treated as the
      matrix dimensions.
    solve_matmul_fn: Callable invoked once per factor as
      `solve_matmul_fn(operator, output, adjoint=adjoint, adjoint_arg=True)`.
    adjoint: Python `bool`. Forwarded to `solve_matmul_fn`; also selects
      whether range or domain dimensions are used for the reshapes.
    adjoint_arg: Python `bool`. If truthy, `x` is treated as its hermitian
      transpose.

  Returns:
    The accumulated result, reshaped so that its last two dimensions are
    `[row_dim, col_dim]`.
  """
  is_complex = np.issubdtype(self.dtype, np.complexfloating)

  output = x

  # The loop below always passes `adjoint_arg=True`, so it needs the
  # transpose of the effective argument. If the caller asked for the
  # adjoint of `x`, that transpose is just conj(x); otherwise transpose.
  if adjoint_arg:
    if is_complex:
      output = math_ops.conj(output)
  else:
    output = linalg.transpose(output)

  for o in reversed(self.operators):
    # Statically compute the reshape.
    if adjoint:
      operator_dimension = o.range_dimension_tensor()
    else:
      operator_dimension = o.domain_dimension_tensor()
    output_shape = _prefer_static_shape(output)

    static_operator_dimension = ops.get_static_value(operator_dimension)
    dim = None
    if static_operator_dimension is not None:
      operator_dimension = static_operator_dimension
      static_output_shape = tensor_shape.TensorShape(output.shape)
      if (static_output_shape[-2] is not None and
          static_output_shape[-1] is not None):
        dim = int(
            static_output_shape[-2] * output_shape[-1] //
            operator_dimension)
    if dim is None:
      # Either the operator dimension or the trailing dims of `output` are
      # not statically known. Compute `dim` dynamically rather than leaving
      # it unbound or reusing the value from a previous iteration.
      dim = _ops.cast(
          output_shape[-2] * output_shape[-1] // operator_dimension,
          dtype=dtypes.int32)

    output_shape = _prefer_static_concat_shape(
        output_shape[:-2], [dim, operator_dimension])
    output = array_ops.reshape(output, shape=output_shape)

    # Conjugate because we are trying to compute A @ B^T, but
    # `LinearOperator` only supports `adjoint_arg`.
    if is_complex:
      output = math_ops.conj(output)

    output = solve_matmul_fn(
        o, output, adjoint=adjoint, adjoint_arg=True)

  if adjoint_arg:
    col_dim = _prefer_static_shape(x)[-2]
  else:
    col_dim = _prefer_static_shape(x)[-1]

  if adjoint:
    row_dim = self.domain_dimension_tensor()
  else:
    row_dim = self.range_dimension_tensor()

  matrix_shape = [row_dim, col_dim]

  output = array_ops.reshape(
      output,
      _prefer_static_concat_shape(
          _prefer_static_shape(output)[:-2], matrix_shape))

  # When the input shape is fully known, record the static output shape.
  if tensor_shape.TensorShape(x.shape).is_fully_defined():
    if adjoint_arg:
      column_dim = tensor_shape.TensorShape(x.shape)[-2]
    else:
      column_dim = tensor_shape.TensorShape(x.shape)[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]

    tensorshape_util.set_shape(
        output, broadcast_batch_shape.concatenate(matrix_dimensions))

  return output