def _args_to_matching_arrays(args_list, dtype_hint=None):
  """Converts every element of `args_list` to a tensor of one shared dtype.

  This mirrors how `tf.concat` chooses the dtype of its inputs:

  * If any argument is already a tensor, the dtype of the first such tensor
    is requested for every argument.
  * Otherwise the first argument is converted with `dtype_hint`. The dtype it
    ends up with is then requested for the remaining arguments.

  Args:
    args_list: A list or tuple of arguments.
    dtype_hint: An optional hint used when converting the args to tensors.
      It is only forwarded when no argument is already a tensor.

  Returns:
    A list of tensors, one per element of `args_list`, in the same order.
  """
  # dtype of the first element that is already a tensor, if any.
  target_dtype = next(
      (arg.dtype for arg in args_list if ops.is_tensor(arg)), None)

  if target_dtype is not None:
    return [ops.convert_to_tensor(arg, target_dtype) for arg in args_list]

  converted = []
  for arg in args_list:
    tensor = ops.convert_to_tensor(arg, target_dtype, dtype_hint=dtype_hint)
    converted.append(tensor)
    if target_dtype is None:
      # The first converted argument fixes the dtype for the rest.
      target_dtype = tensor.dtype
  return converted
def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
  """Verifies input arguments for the case function.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor,
      and a callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation. Unused here.
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list, tuple, or dictionary.
    TypeError: If `pred_fn_pairs` does not contain 2-tuples.
    TypeError: If a predicate is a Tensor whose dtype is not bool, or is not a
      Tensor (and `allow_python_preds` is False), or is neither a Tensor nor a
      Python bool (when `allow_python_preds` is True).
    TypeError: If `fns[i]` is not callable for any i.
    ValueError: If `pred_fn_pairs` is an unordered dict and `exclusive` is
      False.
    ValueError: If `pred_fn_pairs` is empty (the final unpacking of
      predicates and actions fails).

  Returns:
    a tuple <tuple of predicates, tuple of callables>, in the order the pairs
    were traversed.
  """
  del name  # Unused.
  if not isinstance(pred_fn_pairs, (list, tuple, dict)):
    raise TypeError('fns must be a list, tuple, or dict')

  if isinstance(pred_fn_pairs, collections.OrderedDict):
    pred_fn_pairs = pred_fn_pairs.items()
  elif isinstance(pred_fn_pairs, dict):
    # No name to sort on in eager mode. Use dictionary traversal order,
    # which is nondeterministic in versions of Python < 3.6
    if not exclusive:
      raise ValueError(
          'Unordered dictionaries are not supported for the '
          '`pred_fn_pairs` argument when `exclusive=False` and '
          'eager mode is enabled.')
    pred_fn_pairs = list(pred_fn_pairs.items())

  for pred_fn_pair in pred_fn_pairs:
    if not isinstance(pred_fn_pair, tuple) or len(pred_fn_pair) != 2:
      raise TypeError('Each entry in pred_fn_pairs must be a 2-tuple')
    pred, fn = pred_fn_pair
    # Python bools (and tensors without a usable `.name`) would otherwise turn
    # the intended TypeError below into an AttributeError while the message
    # is being formatted; fall back to the predicate itself.
    pred_label = getattr(pred, 'name', pred)

    if ops.is_tensor(pred):
      # NOTE(review): `dtype` must resolve to a dtype namespace exposing
      # `.bool` (TensorFlow spells this `dtypes.bool`); confirm it is bound at
      # module level in this file.
      if pred.dtype != dtype.bool:
        raise TypeError('pred must be Tensor of type bool: %s' % (pred_label,))
    elif not allow_python_preds:
      raise TypeError('pred must be a Tensor, got: %s' % (pred,))
    elif not isinstance(pred, bool):
      raise TypeError('pred must be a Tensor or bool, got: %s' % (pred,))

    if not callable(fn):
      raise TypeError('fn for pred %s must be callable.' % (pred_label,))

  predicates, actions = zip(*pred_fn_pairs)
  return predicates, actions
def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  The parameter is first broadcast to the full batch rank of its parent. Its
  own batch shape is then extracted by dropping the trailing matrix dims.
  Finally, the caller's `slices` are sanitized against that shape and applied.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used
      for inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`. When the parameter's batch shape is statically empty while
      `batch_shape` is not known to be empty, the broadcast parameter is
      returned without slicing.
  """
  # Broadcast the parameter to have full batch rank.
  broadcast_param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  full_shape = (
      broadcast_param.batch_shape_tensor()
      if hasattr(broadcast_param, 'batch_shape_tensor')
      else prefer_static.shape(broadcast_param))

  # Keep only the leading dims; the last `param_ndims_to_matrix_ndims` dims
  # are not part of the batch shape.
  num_batch_dims = array_ops.size(full_shape) - param_ndims_to_matrix_ndims
  param_batch_shape = full_shape[:num_batch_dims]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  parent_batch_maybe_nonempty = (
      ops.get_static_value(array_ops.size(batch_shape)) != 0)
  if parent_batch_maybe_nonempty and (
      ops.get_static_value(array_ops.size(param_batch_shape)) == 0):
    return broadcast_param

  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend the batch slices with full slices over the trailing dims, e.g. with
  # `param_ndims_to_matrix_ndims == 1`, `[i, ..., j]` becomes `[i, ..., j, :]`.
  # An Ellipsis is appended first unless the caller already supplied one;
  # tensor entries are excluded from that membership test (presumably to
  # avoid an elementwise `==` against a tensor).
  if param_ndims_to_matrix_ndims > 0:
    non_tensor_slices = [s for s in slices if not ops.is_tensor(s)]
    if Ellipsis not in non_tensor_slices:
      param_slices.append(Ellipsis)
    param_slices = param_slices + [slice(None)] * param_ndims_to_matrix_ndims

  return broadcast_param.__getitem__(tuple(param_slices))
def _set_graph_parents(self, graph_parents):
  """Validates `graph_parents` and stores it on `self._graph_parents`.

  Intended to be called during derived-class initialization. It lets derived
  classes set graph parents without triggering the deprecation warning that
  is invoked when `graph_parents` is passed to `__init__`.

  Args:
    graph_parents: Iterable over Tensors (or values accepted by
      `linear_operator_util.is_ref`), or `None` for no parents. Lists and
      tuples are stored as given. Any other iterable is first materialized
      into a list, so that validating it does not exhaust it.

  Raises:
    ValueError: If any item is `None` or is neither a ref nor a Tensor.
  """
  # TODO(b/143910018) Remove this function in V3.
  if graph_parents is None:
    graph_parents = []
  elif not isinstance(graph_parents, (list, tuple)):
    # A one-shot iterable (e.g. a generator) would be consumed by the
    # validation loop below and then stored empty; materialize it first.
    graph_parents = list(graph_parents)
  for i, t in enumerate(graph_parents):
    if t is None or not (linear_operator_util.is_ref(t) or ops.is_tensor(t)):
      raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
  self._graph_parents = graph_parents
def __init__(self,
             dtype,
             graph_parents=None,
             is_non_singular=None,
             is_self_adjoint=None,
             is_positive_definite=None,
             is_square=None,
             name=None):
  r"""Initialize the `LinearOperator`.

  **This is a private method for subclass use.**
  **Subclasses should copy-paste this `__init__` documentation.**

  Args:
    dtype: The type of the this `LinearOperator`.  Arguments to `matmul` and
      `solve` will have to be this type. If `None`, it is stored as `None`;
      otherwise it is normalized with `dtypes.as_dtype`.
    graph_parents: Python list of graph prerequisites of this `LinearOperator`
      Typically tensors that are passed during initialization. Lists and
      tuples are stored as given; any other iterable is materialized into a
      list first.
    is_non_singular:  Expect that this operator is non-singular.
    is_self_adjoint:  Expect that this operator is equal to its hermitian
      transpose.  If `dtype` is real, this is equivalent to being symmetric.
    is_positive_definite:  Expect that this operator is positive definite,
      meaning the quadratic form `x^H A x` has positive real part for all
      nonzero `x`.  Note that we do not require the operator to be
      self-adjoint to be positive-definite.  See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
    is_square:  Expect that this operator acts like square [batch] matrices.
    name: A name for this `LinearOperator`. Defaults to the class name.

  Raises:
    ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
    ValueError:  If hints are set incorrectly.
  """
  # Check and auto-set flags. A stronger hint implies a weaker one
  # (positive definite => non-singular => square; self-adjoint => square);
  # an explicit `False` that contradicts an implied `True` is rejected.
  if is_positive_definite:
    if is_non_singular is False:
      raise ValueError("A positive definite matrix is always non-singular.")
    is_non_singular = True

  if is_non_singular:
    if is_square is False:
      raise ValueError("A non-singular matrix is always square.")
    is_square = True

  if is_self_adjoint:
    if is_square is False:
      raise ValueError("A self-adjoint matrix is always square.")
    is_square = True

  self._is_square_set_or_implied_by_hints = is_square

  if graph_parents is None:
    graph_parents = []
  elif not isinstance(graph_parents, (list, tuple)):
    # A one-shot iterable would be exhausted by the validation loop below.
    graph_parents = list(graph_parents)
  for i, t in enumerate(graph_parents):
    if t is None or not ops.is_tensor(t):
      raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))

  # Test against None rather than truthiness: a NumPy `dtype` object can be
  # falsy (its `__len__` counts structured fields), which would silently skip
  # the normalization.
  self._dtype = dtypes.as_dtype(dtype) if dtype is not None else None
  self._graph_parents = graph_parents
  self._is_non_singular = is_non_singular
  self._is_self_adjoint = is_self_adjoint
  self._is_positive_definite = is_positive_definite
  self._name = name or type(self).__name__
# Public `tf.debugging.*` endpoints for this backend. Each one is produced by
# `utils.copy_docstring`, called with the TensorFlow symbol name (a string) and
# the local implementation. Presumably it attaches that symbol's documentation
# to the implementation — NOTE(review): confirm against `utils.copy_docstring`.
assert_negative = utils.copy_docstring('tf.debugging.assert_negative',
                                       _assert_negative)

assert_non_negative = utils.copy_docstring('tf.debugging.assert_non_negative',
                                           _assert_non_negative)

assert_non_positive = utils.copy_docstring('tf.debugging.assert_non_positive',
                                           _assert_non_positive)

assert_none_equal = utils.copy_docstring('tf.debugging.assert_none_equal',
                                         _assert_none_equal)

assert_positive = utils.copy_docstring('tf.debugging.assert_positive',
                                       _assert_positive)

assert_proper_iterable = utils.copy_docstring(
    'tf.debugging.assert_proper_iterable',
    _assert_proper_iterable)

assert_rank_at_least = utils.copy_docstring(
    'tf.debugging.assert_rank_at_least',
    _assert_rank_at_least)

assert_rank_in = utils.copy_docstring('tf.debugging.assert_rank_in',
                                      _assert_rank_in)

# Pass-through: returns its first argument unchanged and ignores all other
# positional and keyword arguments (no numeric checking happens here).
check_numerics = utils.copy_docstring('tf.debugging.check_numerics',
                                      lambda x, *_, **__: x)

# Returns a true value only when `x` is a tensor (per `ops.is_tensor`) whose
# dtype is a NumPy numeric type (`np.issubdtype(..., np.number)`). NumPy's
# bool dtype is not a `np.number` subtype, so bool tensors are not "numeric".
is_numeric_tensor = utils.copy_docstring(
    'tf.debugging.is_numeric_tensor',
    lambda x: ops.is_tensor(x) and np.issubdtype(x.dtype, np.number))