def _slice_single_param(param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used
      for inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according
      to `slices`.
  """
  # Broadcast the parameter to have full batch rank.
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  if hasattr(param, 'batch_shape_tensor'):
    param_batch_shape = param.batch_shape_tensor()
  else:
    param_batch_shape = prefer_static.shape(param)
  # Truncate by `param_ndims_to_matrix_ndims`.
  param_batch_rank = array_ops.size(param_batch_shape)
  param_batch_shape = param_batch_shape[
      :(param_batch_rank - param_ndims_to_matrix_ndims)]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  if (ops.get_static_value(array_ops.size(batch_shape)) != 0 and
      ops.get_static_value(array_ops.size(param_batch_shape)) == 0):
    return param
  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend `param_slices` (which represents slicing into the parameter's
  # batch shape) with the parameter's event ndims. For example, if
  # `param_ndims_to_matrix_ndims == 1`, then `[i, ..., j]` would become
  # `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    if Ellipsis not in [slc for slc in slices if not ops.is_tensor(slc)]:
      param_slices.append(Ellipsis)
    param_slices = param_slices + [slice(None)] * param_ndims_to_matrix_ndims
  return param.__getitem__(tuple(param_slices))
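# A minimal sketch (not part of this module) of the slice-extension rule
# documented above, using only public TensorFlow ops. The function name
# `_demo_slice_extension` and the shapes are illustrative assumptions.
def _demo_slice_extension():
  import tensorflow as tf

  # A parameter with batch shape [4, 3] and one right-most matrix dimension,
  # i.e. `param_ndims_to_matrix_ndims == 1`.
  scale = tf.ones([4, 3, 2])
  # Batch slices as received by `__getitem__`; no `Ellipsis` is present, so
  # one is appended before padding with `slice(None)` for the matrix
  # dimension, turning `[0:2]` into `[0:2, ..., :]`.
  param_slices = [slice(0, 2)]
  param_slices.append(Ellipsis)
  param_slices = param_slices + [slice(None)] * 1
  sliced = scale[tuple(param_slices)]
  assert sliced.shape == (2, 3, 2)  # Batch dim sliced; matrix dim untouched.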
def _tensor_rank_tensor(self, shape=None):
  # `shape` may be passed in if this can be pre-computed in a more efficient
  # manner, e.g. without excessive Tensor conversions.
  if self.tensor_rank is not None:
    return ops.convert_to_tensor(self.tensor_rank)
  else:
    shape = self.shape_tensor() if shape is None else shape
    return array_ops.size(shape)
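# A hedged sketch (an assumption, not from this module) of the static-first
# pattern used above, shown with plain TensorFlow tensors: use the statically
# known rank when available, otherwise fall back to a runtime computation.
def _demo_static_first_rank():
  import tensorflow as tf

  x = tf.ones([3, 2, 2])
  static_rank = x.shape.rank            # Python int, known at trace time.
  dynamic_rank = tf.size(tf.shape(x))   # `int32` Tensor, computed at runtime.
  assert static_rank == int(dynamic_rank) == 3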
def tensor_rank_tensor(self, name="tensor_rank_tensor"):
  """Rank (in the sense of tensors) of matrix corresponding to this operator.

  If this operator acts like the batch matrix `A` with
  `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

  Args:
    name: A name for this `Op`.

  Returns:
    `int32` `Tensor`, determined at runtime.
  """
  # Derived classes get this "for free" once .shape() is implemented.
  with self._name_scope(name):
    # Prefer to use statically defined shape if available.
    if self.tensor_rank is not None:
      return ops.convert_to_tensor(self.tensor_rank)
    else:
      return array_ops.size(self.shape_tensor())
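# A minimal usage sketch (not part of this module) of `tensor_rank_tensor`
# through the public `tf.linalg` API; the operator choice and shapes are
# illustrative assumptions.
def _demo_tensor_rank_tensor():
  import tensorflow as tf

  # A batch of three 2x2 diagonal operators: the operator's shape is
  # [3, 2, 2], so b = 1 and the tensor rank is b + 2 = 3.
  operator = tf.linalg.LinearOperatorDiag(tf.ones([3, 2]))
  rank = operator.tensor_rank_tensor()  # `int32` Tensor.
  assert int(rank) == 3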