def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
    """Split `arg` into blocks whose sizes are given by `block_dims`.

    Specifically, if we have a blockwise lower-triangular matrix with `J`
    diagonal blocks of shape `[M_j, M_j]`, `j = 0, 1, ..., J - 1`, this method
    splits `arg` on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

    The static sizes in `block_dims` are used when every one of them is known;
    if any is unknown (`None`), the sizes returned by `block_dims_fn()` are
    used for all blocks instead.

    Args:
      block_dims: Iterable of per-block sizes. Each entry may be an object
        exposing a `.value` attribute (e.g. a `Dimension`), a Python `int`,
        or `None`; a `None` size is treated as statically unknown.
      block_dims_fn: Zero-argument callable returning an iterable of
        `Tensor`s holding the block sizes. Only called when some static size
        in `block_dims` is unknown.
      arg: `Tensor`. `arg` is split into `J` tensors.
      axis: Python `Integer` representing the axis to split `arg` on.

    Returns:
      A list of `Tensor`s.
    """
    # Unwrap Dimension-like objects via `.value`; plain ints / None pass
    # through unchanged.
    block_sizes = [getattr(dim, "value", dim) for dim in block_dims]
    if any(size is None for size in block_sizes):
        # At least one size is only known at runtime: use the dynamic sizes
        # for every block.
        block_sizes = block_dims_fn()
    return array_ops.split(arg, block_sizes, axis=axis)
def _split_input_into_blocks(self, x, axis=-1):
    """Split `x` into blocks matching `operators`'s `domain_dimension`.

    Specifically, if we have a block diagonal matrix, with block sizes
    `[M_j, M_j] j = 1..J`, this method splits `x` on `axis` into `J`
    tensors, whose shape at `axis` is `M_j`.

    Static (Python int) block sizes are used whenever every operator's
    `domain_dimension` is statically known, even if other parts of
    `self.shape` (e.g. batch dimensions) are not. Otherwise the runtime
    sizes from each operator's `domain_dimension_tensor()` are used.

    Args:
      x: `Tensor`. `x` is split into `J` tensors.
      axis: Python `Integer` representing the axis to split `x` on.

    Returns:
      A list of `Tensor`s.
    """
    static_sizes = [op.domain_dimension.value for op in self.operators]
    if all(size is not None for size in static_sizes):
        # Only the domain dimensions determine the split sizes, so an
        # unknown batch shape should not force building size tensors.
        block_sizes = static_sizes
    else:
        # Some block size is only known at runtime: use dynamic sizes for
        # every block.
        block_sizes = [op.domain_dimension_tensor() for op in self.operators]
    return array_ops.split(x, block_sizes, axis=axis)