Example #1
def _args_to_matching_arrays(args_list, dtype_hint=None):
  """Converts a list to array using the first element for dtype.

  This method is used to match the behavior of `tf.concat`.

  Args:
    args_list: A list or tuple of arguments.
    dtype_hint: An optional hint used when converting the args to tensors.
  Returns:
    A list of tensors.
  """
  dtype = None
  for arg in args_list:
    if ops.is_tensor(arg):
      dtype = arg.dtype
      break
  if dtype is None:
    ret = []
    for arg in args_list:
      ret.append(ops.convert_to_tensor(arg, dtype, dtype_hint=dtype_hint))
      if dtype is None:
        dtype = ret[-1].dtype
  else:
    ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
  return ret
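
The helper above relies on the backend's internal `ops` module, but the dtype-matching rule it implements can be reproduced with public TensorFlow APIs. Below is a minimal, self-contained sketch assuming TensorFlow 2.x; `args_to_matching_tensors` is an illustrative name, not the library function.

# Illustrative sketch only: the first Tensor in the list decides the dtype and
# every other argument is converted to it, mirroring `tf.concat` behavior.
import tensorflow as tf

def args_to_matching_tensors(args_list, dtype_hint=None):  # hypothetical name
    dtype = next((a.dtype for a in args_list if tf.is_tensor(a)), None)
    if dtype is None:
        # No Tensor present: the first conversion (guided by dtype_hint) decides.
        ret = []
        for a in args_list:
            ret.append(tf.convert_to_tensor(a, dtype, dtype_hint=dtype_hint))
            if dtype is None:
                dtype = ret[-1].dtype
        return ret
    return [tf.convert_to_tensor(a, dtype) for a in args_list]

print(args_to_matching_tensors([tf.constant(1., tf.float64), 2, 3])[1].dtype)  # float64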
Example #2
def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
    """Verifies input arguments for the case function.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation.
    allow_python_preds: If `True`, `pred_fn_pairs` may contain Python bools in
      addition to boolean Tensors.
  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  Returns:
    A tuple <list of scalar bool tensors, list of callables>.
  """
    del name
    if not isinstance(pred_fn_pairs, (list, tuple, dict)):
        raise TypeError('fns must be a list, tuple, or dict')

    if isinstance(pred_fn_pairs, collections.OrderedDict):
        pred_fn_pairs = pred_fn_pairs.items()
    elif isinstance(pred_fn_pairs, dict):
        # No name to sort on in eager mode. Use dictionary traversal order,
        # which is nondeterministic in versions of Python < 3.6
        if not exclusive:
            raise ValueError(
                'Unordered dictionaries are not supported for the '
                '`pred_fn_pairs` argument when `exclusive=False` and '
                'eager mode is enabled.')
        pred_fn_pairs = list(pred_fn_pairs.items())
    for pred_fn_pair in pred_fn_pairs:
        if not isinstance(pred_fn_pair, tuple) or len(pred_fn_pair) != 2:
            raise TypeError('Each entry in pred_fn_pairs must be a 2-tuple')
        pred, fn = pred_fn_pair

        if ops.is_tensor(pred):
            if pred.dtype != dtype.bool:
                raise TypeError('pred must be Tensor of type bool: %s' %
                                pred.name)
        elif not allow_python_preds:
            raise TypeError('pred must be a Tensor, got: %s' % pred)
        elif not isinstance(pred, bool):
            raise TypeError('pred must be a Tensor or bool, got: %s' % pred)

        if not callable(fn):
            raise TypeError('fn for pred %s must be callable.' % pred.name)

    predicates, actions = zip(*pred_fn_pairs)
    return predicates, actions
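
This verifier is an internal building block; the public entry point that consumes `pred_fn_pairs` in this shape is `tf.case`. Here is a short usage sketch (assuming TensorFlow 2.x in eager mode) of the input format being validated: an ordered collection of (scalar boolean tensor, zero-argument callable) pairs.

import tensorflow as tf

x, y = tf.constant(3), tf.constant(5)
pred_fn_pairs = [
    (tf.less(x, y), lambda: tf.constant(17)),      # predicate: scalar bool Tensor
    (tf.greater(x, y), lambda: tf.constant(23)),   # callable: returns the branch value
]
# `exclusive=True` asserts that at most one predicate evaluates to True.
result = tf.case(pred_fn_pairs, default=lambda: tf.constant(-1), exclusive=True)
print(result.numpy())  # 17, because x < y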
Example #3
def _slice_single_param(param, param_ndims_to_matrix_ndims, slices,
                        batch_shape):
    """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
    # Broadcast the parameter to have full batch rank.
    param = _broadcast_parameter_with_batch_shape(
        param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

    if hasattr(param, 'batch_shape_tensor'):
        param_batch_shape = param.batch_shape_tensor()
    else:
        param_batch_shape = prefer_static.shape(param)
    # Truncate by param_ndims_to_matrix_ndims
    param_batch_rank = array_ops.size(param_batch_shape)
    param_batch_shape = param_batch_shape[:(param_batch_rank -
                                            param_ndims_to_matrix_ndims)]

    # At this point the param should have full batch rank, *unless* it's an
    # atomic object like `tfb.Identity()` incapable of having any batch rank.
    if (ops.get_static_value(array_ops.size(batch_shape)) != 0
            and ops.get_static_value(array_ops.size(param_batch_shape)) == 0):
        return param
    param_slices = _sanitize_slices(slices,
                                    intended_shape=batch_shape,
                                    deficient_shape=param_batch_shape)

    # Extend `param_slices` (which represents slicing into the parameter's
    # batch shape) with the parameter's event ndims. For example, if
    # `param_ndims_to_matrix_ndims == 1`, then `[i, ..., j]` would become
    # `[i, ..., j, :]`.
    if param_ndims_to_matrix_ndims > 0:
        if Ellipsis not in [slc for slc in slices if not ops.is_tensor(slc)]:
            param_slices.append(Ellipsis)
        param_slices = param_slices + [slice(None)
                                       ] * param_ndims_to_matrix_ndims
    return param.__getitem__(tuple(param_slices))
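
The core trick here is how user-supplied slices, which address only the batch dimensions, are padded so the parameter's right-most matrix/event dimensions are left untouched. Below is a stripped-down, NumPy-only sketch of just that padding step; `slice_batch_only` is a hypothetical name, and the broadcasting and slice sanitizing done by the real helper are omitted.

import numpy as np

def slice_batch_only(param, param_ndims_to_matrix_ndims, slices):
    slices = list(slices)
    if Ellipsis not in slices:
        slices.append(Ellipsis)  # pass remaining batch dims through unchanged
    # Keep the right-most "matrix" dims intact by appending full slices.
    slices += [slice(None)] * param_ndims_to_matrix_ndims
    return param[tuple(slices)]

scale_diag = np.ones([4, 7, 3])  # batch shape [4, 7], one right-most matrix dim
sliced = slice_batch_only(scale_diag, 1, (slice(0, 2),))  # acts like op[0:2]
print(sliced.shape)  # (2, 7, 3): only the leading batch dimension was sliced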
Example #4
  def _set_graph_parents(self, graph_parents):
    """Set self._graph_parents.  Called during derived class init.

    This method allows derived classes to set graph_parents, without triggering
    a deprecation warning (which is invoked if `graph_parents` is passed during
    `__init__`).

    Args:
      graph_parents: Iterable over Tensors.
    """
    # TODO(b/143910018) Remove this function in V3.
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not (linear_operator_util.is_ref(t) or
                           ops.is_tensor(t)):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._graph_parents = graph_parents
  def __init__(self,
               dtype,
               graph_parents=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize the `LinearOperator`.

    **This is a private method for subclass use.**
    **Subclasses should copy-paste this `__init__` documentation.**

    Args:
      dtype: The type of this `LinearOperator`.  Arguments to `matmul` and
        `solve` will have to be this type.
      graph_parents: Python list of graph prerequisites of this
        `LinearOperator`.  Typically tensors that are passed during
        initialization.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `dtype` is real, this is equivalent to being symmetric.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
      ValueError:  If hints are set incorrectly.
    """
    # Check and auto-set flags.
    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not ops.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._dtype = dtypes.as_dtype(dtype) if dtype else dtype
    self._graph_parents = graph_parents
    self._is_non_singular = is_non_singular
    self._is_self_adjoint = is_self_adjoint
    self._is_positive_definite = is_positive_definite
    self._name = name or type(self).__name__
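
The hint bookkeeping at the top of `__init__` follows a simple implication chain: positive definite implies non-singular, non-singular implies square, and self-adjoint implies square, with contradictory hints rejected. A standalone sketch of just that logic (`resolve_hints` is an illustrative name):

def resolve_hints(is_non_singular=None, is_self_adjoint=None,
                  is_positive_definite=None, is_square=None):
    # Positive definite => non-singular => square; self-adjoint => square.
    if is_positive_definite:
        if is_non_singular is False:
            raise ValueError("A positive definite matrix is always non-singular.")
        is_non_singular = True
    if is_non_singular:
        if is_square is False:
            raise ValueError("A non-singular matrix is always square.")
        is_square = True
    if is_self_adjoint:
        if is_square is False:
            raise ValueError("A self-adjoint matrix is always square.")
        is_square = True
    return is_non_singular, is_self_adjoint, is_positive_definite, is_square

print(resolve_hints(is_positive_definite=True))  # (True, None, True, True)
# resolve_hints(is_positive_definite=True, is_square=False) raises ValueError.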
Example #6
assert_negative = utils.copy_docstring('tf.debugging.assert_negative',
                                       _assert_negative)

assert_non_negative = utils.copy_docstring('tf.debugging.assert_non_negative',
                                           _assert_non_negative)

assert_non_positive = utils.copy_docstring('tf.debugging.assert_non_positive',
                                           _assert_non_positive)

assert_none_equal = utils.copy_docstring('tf.debugging.assert_none_equal',
                                         _assert_none_equal)

assert_positive = utils.copy_docstring('tf.debugging.assert_positive',
                                       _assert_positive)

assert_proper_iterable = utils.copy_docstring(
    'tf.debugging.assert_proper_iterable', _assert_proper_iterable)

assert_rank_at_least = utils.copy_docstring(
    'tf.debugging.assert_rank_at_least', _assert_rank_at_least)

assert_rank_in = utils.copy_docstring('tf.debugging.assert_rank_in',
                                      _assert_rank_in)

check_numerics = utils.copy_docstring('tf.debugging.check_numerics',
                                      lambda x, *_, **__: x)

is_numeric_tensor = utils.copy_docstring(
    'tf.debugging.is_numeric_tensor',
    lambda x: ops.is_tensor(x) and np.issubdtype(x.dtype, np.number))
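
These assignments all follow one pattern: a NumPy-backed replacement callable does the actual work, and `utils.copy_docstring` attaches the documented TensorFlow name to it. Below is a rough, NumPy-only sketch of the idea; `copy_docstring_sketch` is a hypothetical stand-in, and the real helper presumably copies the TF symbol's docstring rather than fabricating one.

import numpy as np

def copy_docstring_sketch(original_name, fn):  # hypothetical stand-in
    fn.__doc__ = "See `%s`." % original_name   # real helper copies the TF docstring
    fn.__name__ = original_name.rsplit('.', 1)[-1]
    return fn

is_numeric_tensor = copy_docstring_sketch(
    'tf.debugging.is_numeric_tensor',
    lambda x: hasattr(x, 'dtype') and np.issubdtype(x.dtype, np.number))

print(is_numeric_tensor(np.zeros([2, 3])))      # True
print(is_numeric_tensor(np.array(['a', 'b'])))  # False: string dtype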