Example #1
0
def _slice_single_param(param, param_ndims_to_matrix_ndims, slices,
                        batch_shape):
    """Slices into the batch shape of a single parameter.

    Args:
      param: The original parameter to slice; either a `Tensor` or an object
        with batch shape (LinearOperator).
      param_ndims_to_matrix_ndims: `int` number of right-most dimensions used
        for inferring matrix shape of the `LinearOperator`. For non-Tensor
        parameters, this is the number of this param's batch dimensions used
        by the matrix shape of the parent object.
      slices: iterable of slices received by `__getitem__`.
      batch_shape: The parameterized object's batch shape `Tensor`.

    Returns:
      new_param: Instance of the same type as `param`, batch-sliced according
        to `slices`. If the broadcast parameter ends up with an empty batch
        shape while `batch_shape` is (statically) non-empty, `param` is
        returned unsliced.
    """
    # Broadcast the parameter to have full batch rank.
    param = _broadcast_parameter_with_batch_shape(
        param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

    if hasattr(param, 'batch_shape_tensor'):
        param_batch_shape = param.batch_shape_tensor()
    else:
        param_batch_shape = prefer_static.shape(param)
    # Drop the right-most `param_ndims_to_matrix_ndims` dimensions so that
    # only the parameter's batch dimensions remain.
    param_batch_rank = array_ops.size(param_batch_shape)
    param_batch_shape = param_batch_shape[:(param_batch_rank -
                                            param_ndims_to_matrix_ndims)]

    # At this point the param should have full batch rank, *unless* it's an
    # atomic object like `tfb.Identity()` incapable of having any batch rank.
    if (ops.get_static_value(array_ops.size(batch_shape)) != 0
            and ops.get_static_value(array_ops.size(param_batch_shape)) == 0):
        return param
    param_slices = _sanitize_slices(slices,
                                    intended_shape=batch_shape,
                                    deficient_shape=param_batch_shape)

    # Extend `param_slices` (which represents slicing into the parameter's
    # batch shape) with the parameter's trailing matrix dims. For example, if
    # `param_ndims_to_matrix_ndims == 1`, then `[i, ..., j]` would become
    # `[i, ..., j, :]`.
    if param_ndims_to_matrix_ndims > 0:
        # `Ellipsis` is a singleton, so detect it by identity. A `==`-based
        # membership test would call `__eq__` on every slice element, which
        # for tensor- or array-valued indices may be elementwise and not
        # usable as a bool.
        if not any(slc is Ellipsis for slc in slices):
            param_slices.append(Ellipsis)
        param_slices = (param_slices +
                        [slice(None)] * param_ndims_to_matrix_ndims)
    return param.__getitem__(tuple(param_slices))
Example #2
0
 def _tensor_rank_tensor(self, shape=None):
     """Computes this operator's tensor rank as a `Tensor`.

     Args:
       shape: Optional shape `Tensor` for this operator. Callers that have
         already computed it (e.g. without excessive Tensor conversions) may
         pass it to avoid calling `self.shape_tensor()` again.

     Returns:
       `self.tensor_rank` converted to a `Tensor` when it is not `None`;
       otherwise the number of elements in the shape `Tensor`.
     """
     static_rank = self.tensor_rank
     if static_rank is not None:
         return ops.convert_to_tensor(static_rank)
     # No static rank: fall back to counting the dimensions of the shape.
     if shape is None:
         shape = self.shape_tensor()
     return array_ops.size(shape)
    def tensor_rank_tensor(self, name="tensor_rank_tensor"):
        """Rank (in the sense of tensors) of matrix corresponding to this operator.

        If this operator acts like the batch matrix `A` with
        `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

        Args:
          name:  A name for this `Op`.

        Returns:
          `int32` `Tensor`, determined at runtime.
        """
        # Derived classes get this "for free" once .shape() is implemented.
        with self._name_scope(name):
            # Delegate to `_tensor_rank_tensor` instead of duplicating its
            # logic, so that subclasses overriding it (e.g. to use a more
            # cheaply computed shape) are honored by this public method too.
            return self._tensor_rank_tensor()