示例#1
0
def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
  """Split `arg` into blocks whose sizes along `axis` come from `block_dims`.

  Specifically, if we have a blockwise lower-triangular matrix, with block
  sizes along the diagonal `[M_j, M_j] j = 1..J`, this method splits `arg`
  on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

  Static sizes are preferred: if every entry of `block_dims` has a known
  size, those sizes are passed to `split`. If any static size is unknown
  (`None`), the sizes returned by `block_dims_fn()` are used instead.

  Args:
    block_dims: Iterable of static block sizes, one per block. Each entry may
      be an object exposing the size as `.value` (e.g. a `Dimension`), a
      plain Python `int`, or `None` for an unknown size.
    block_dims_fn: Zero-argument callable returning an iterable of `Tensor`s
      giving the block sizes; only called when some static size is unknown.
    arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `Integer` representing the axis to split `arg` on.

  Returns:
    A list of `Tensor`s.
  """
  # Accept both objects carrying the size in `.value` and raw ints / `None`
  # (NOTE(review): some TensorFlow versions' shape APIs yield the latter), so
  # callers need not normalize the static sizes beforehand.
  block_sizes = [getattr(dim, "value", dim) for dim in block_dims]
  if any(d is None for d in block_sizes):
    # At least one block size is not known statically; use runtime sizes.
    block_sizes = block_dims_fn()
  return array_ops.split(arg, block_sizes, axis=axis)
    def _split_input_into_blocks(self, x, axis=-1):
        """Split `x` into blocks matching `operators`'s `domain_dimension`.

        Specifically, if we have a block diagonal matrix, with block sizes
        `[M_j, M_j] j = 1..J`, this method splits `x` on `axis` into `J`
        tensors, whose shape at `axis` is `M_j`.

        When `self.shape` is fully defined, each operator's static
        `domain_dimension` is used as the block size. Otherwise each
        operator's `domain_dimension_tensor()` is used.

        Args:
          x: `Tensor`. `x` is split into `J` tensors.
          axis: Python `Integer` representing the axis to split `x` on.

        Returns:
          A list of `Tensor`s.
        """
        if _ops.TensorShape(self.shape).is_fully_defined():
            # Static sizes: read `.value` when present (e.g. `Dimension`),
            # otherwise use the returned size as-is (e.g. a plain int).
            dims = [operator.domain_dimension for operator in self.operators]
            block_sizes = [getattr(dim, "value", dim) for dim in dims]
        else:
            # The overall shape is not fully known statically, so fall back
            # to the runtime domain dimensions of each operator.
            block_sizes = [
                operator.domain_dimension_tensor()
                for operator in self.operators
            ]
        return array_ops.split(x, block_sizes, axis=axis)