예제 #1
0
 def _shape_tensor(self):
     """Build the dynamic shape of this operator.

     The batch part is the broadcast of the base operator's batch shape
     with every dimension of `u` and `v` except their last two.  The
     matrix part is the last two entries of the base operator's shape.

     Returns:
       The concatenation (along axis 0) of the broadcast batch shape and
       the base operator's trailing two shape entries.
     """
     base_batch = self.base_operator.batch_shape_tensor()
     u_batch = array_ops.shape(self.u)[:-2]
     merged_batch = array_ops.broadcast_dynamic_shape(base_batch, u_batch)

     v_batch = array_ops.shape(self.v)[:-2]
     merged_batch = array_ops.broadcast_dynamic_shape(merged_batch, v_batch)

     # The matrix dimensions are taken from the base operator unchanged.
     matrix_dims = self.base_operator.shape_tensor()[-2:]
     return array_ops.concat([merged_batch, matrix_dims], axis=0)
 def _shape_tensor(self):
     """Build the dynamic shape of this operator.

     The batch part is the broadcast of four batch shapes, folded in this
     order: the base operator's, the diag operator's, and the leading
     dimensions (all but the last two) of `u` and of `v`.  The matrix
     part is the last two entries of the base operator's shape.

     Returns:
       The concatenation (along axis 0) of the broadcast batch shape and
       the base operator's trailing two shape entries.
     """
     base_batch = self.base_operator.batch_shape_tensor()
     diag_batch = self.diag_operator.batch_shape_tensor()
     merged_batch = array_ops.broadcast_dynamic_shape(base_batch, diag_batch)

     u_batch = prefer_static.shape(self.u)[:-2]
     merged_batch = array_ops.broadcast_dynamic_shape(merged_batch, u_batch)

     v_batch = prefer_static.shape(self.v)[:-2]
     merged_batch = array_ops.broadcast_dynamic_shape(merged_batch, v_batch)

     # The matrix dimensions are taken from the base operator unchanged.
     matrix_dims = self.base_operator.shape_tensor()[-2:]
     return prefer_static.concat([merged_batch, matrix_dims], axis=0)
 def _to_dense(self):
     """Assemble the dense matrix defined by `self.row` and `self.col`.

     `row` and `col` are broadcast against each other, packed into one
     vector along the last axis, and the matrix is produced by gathering
     from that vector with an `[n, n]` index grid.

     Returns:
       The result of gathering along the last axis of the packed vector,
       where `n` is the size of the last axis of `row` before
       broadcasting.
     """
     row = ops.convert_to_tensor(self.row)
     col = ops.convert_to_tensor(self.col)
     total_shape = array_ops.broadcast_dynamic_shape(
         array_ops.shape(row), array_ops.shape(col))
     # `n` is read before broadcasting, i.e. from the original `row`.
     n = array_ops.shape(row)[-1]
     row = _ops.broadcast_to(row, total_shape)
     col = _ops.broadcast_to(col, total_shape)

     # Pack the reversed column followed by the row without its first
     # entry.  That is n + (n - 1) = 2 * n - 1 elements, laid out as
     #   [col[n-1], ..., col[1], col[0], row[1], ..., row[n-1]].
     reversed_col = array_ops.reverse(col, axis=[-1])
     elements = array_ops.concat([reversed_col, row[..., 1:]], axis=-1)

     # Entry (i, j) of the matrix is elements[(n - 1 - i) + j]: row i is
     # the window of n consecutive elements starting at offset n - 1 - i.
     # Hence row 0 is [col[0], row[1], ..., row[n-1]] and every row's
     # later windows slide one step towards the column entries.
     column_offsets = math_ops.range(0, n)
     row_offsets = math_ops.range(n - 1, -1, -1)[..., _ops.newaxis]
     # The modulus keeps every index non-negative (for the gather) and
     # strictly below the 2 * n - 1 packed elements.
     indices = math_ops.mod(column_offsets + row_offsets, 2 * n - 1)
     return array_ops.gather(elements, indices, axis=-1)
 def _shape_tensor(self, row=None, col=None):
     """Compute the dynamic shape from the shapes of `row` and `col`.

     Args:
       row: Optional override for `self.row`; used when not `None`.
       col: Optional override for `self.col`; used when not `None`.

     Returns:
       The broadcast shape of `row` and `col` with its own last entry
       appended once more, so the two trailing dimensions are equal.
     """
     if row is None:
         row = self.row
     if col is None:
         col = self.col

     vector_shape = array_ops.broadcast_dynamic_shape(
         array_ops.shape(row), array_ops.shape(col))
     # Repeat the trailing dimension to form the square matrix part.
     last_dim = vector_shape[-1]
     return array_ops.concat((vector_shape, [last_dim]), 0)
예제 #5
0
def _broadcast_parameter_with_batch_shape(param, param_ndims_to_matrix_ndims,
                                          batch_shape):
    """Broadcast `param` against `batch_shape`, recursing into operators.

    Args:
      param: Either an object exposing `batch_shape_tensor` (treated as a
        nested operator and rebuilt from broadcast parameters) or a value
        that is broadcast directly.
      param_ndims_to_matrix_ndims: Number of trailing dimensions of
        `param` that sit to the right of the batch dimensions; each is
        represented by a size-1 entry in the target shape so that it
        keeps its own size after broadcasting.
      batch_shape: The batch shape to broadcast against.

    Returns:
      For an operator-like `param`, a new instance of the same type built
      from its `parameters` with the broadcast sub-parameters substituted.
      Otherwise, `param` broadcast to the dynamic broadcast of its own
      shape with `batch_shape + [1] * param_ndims_to_matrix_ndims`.
    """
    if not hasattr(param, 'batch_shape_tensor'):
        # Leaf case: pad the batch shape with ones for the trailing dims,
        # then broadcast that against the parameter's own shape.
        trailing_ones = array_ops.ones(
            [param_ndims_to_matrix_ndims], dtype=dtypes.int32)
        padded_batch_shape = prefer_static.concat(
            [batch_shape, trailing_ones], axis=0)
        return _ops.broadcast_to(
            param,
            array_ops.broadcast_dynamic_shape(
                padded_batch_shape, prefer_static.shape(param)))

    # Operator case: broadcast every declared parameter of the operator
    # with this same function, then rebuild the operator.
    broadcast_leaf = functools.partial(
        _broadcast_parameter_with_batch_shape, batch_shape=batch_shape)
    ndims_by_name = param._experimental_parameter_ndims_to_matrix_ndims  # pylint:disable=protected-access
    overrides = {}
    for name, ndims in ndims_by_name.items():
        value = getattr(param, name)
        # `value` doubles as the shallow structure and the first input;
        # `ndims` supplies the matching trailing-dimension counts.
        overrides[name] = nest.map_structure_up_to(
            value, broadcast_leaf, value, ndims)

    rebuilt_parameters = dict(param.parameters, **overrides)
    return type(param)(**rebuilt_parameters)
  def _shape_tensor(self):
    """Return the shape of this block operator as an int32 tensor.

    When the static shape is fully defined it is converted directly.
    Otherwise the shape is assembled dynamically: the matrix part is
    `[range_dimension, domain_dimension]`, each the sum of the
    corresponding per-block dimension tensors, and the batch part is the
    broadcast of the visited blocks' batch shapes.

    Returns:
      The concatenation of the batch shape and the two matrix dimensions,
      ordered `[..., range_dimension, domain_dimension]`.
    """
    # Avoid messy broadcasting if possible.
    static_shape = tensor_shape.TensorShape(self.shape)
    if static_shape.is_fully_defined():
      return ops.convert_to_tensor(
          static_shape.as_list(), dtype=dtypes.int32, name="shape")

    range_dimension = sum(self._block_range_dimension_tensors())
    domain_dimension = sum(self._block_domain_dimension_tensors())
    # Operator shapes are `[..., M, N]` with M the range dimension and N
    # the domain dimension, so the range dimension must come first.  The
    # two coincide for square operators.
    matrix_shape = array_ops.stack([range_dimension, domain_dimension])

    # Start from the first block and fold in every block of the later
    # rows.  NOTE(review): row 0 contributes only `operators[0][0]` --
    # presumably it holds a single block; confirm against the class.
    batch_shape = self.operators[0][0].batch_shape_tensor()
    for row in self.operators[1:]:
      for operator in row:
        batch_shape = array_ops.broadcast_dynamic_shape(
            batch_shape, operator.batch_shape_tensor())

    return prefer_static.concat((batch_shape, matrix_shape), 0)
    def _shape_tensor(self):
        """Compute the dynamic shape from all component operators.

        The domain and range dimensions are the products of the
        components' respective dimensions; the batch shape is the
        broadcast of the components' batch shapes.

        Returns:
          The concatenation of the broadcast batch shape with
          `[range_dimension, domain_dimension]`.
        """
        first = self.operators[0]
        remaining = self.operators[1:]

        # Product of the domain dimensions across all components.
        domain_dimension = first.domain_dimension_tensor()
        for op in remaining:
            domain_dimension *= op.domain_dimension_tensor()

        # Product of the range dimensions across all components.
        range_dimension = first.range_dimension_tensor()
        for op in remaining:
            range_dimension *= op.range_dimension_tensor()

        matrix_shape = [range_dimension, domain_dimension]

        # Fold the batch shapes together pairwise.  The original comment
        # notes that the broadcast also checks compatibility.
        batch_shape = first.batch_shape_tensor()
        for op in remaining:
            batch_shape = array_ops.broadcast_dynamic_shape(
                batch_shape, op.batch_shape_tensor())

        return array_ops.concat((batch_shape, matrix_shape), 0)
예제 #8
0
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast the leading (batch) dimensions of a collection of matrices.

  Every input keeps its own last two (matrix) dimensions; only the
  dimensions in front of them are broadcast to a common batch shape.

  Broadcasting one batch dim of two simple matrices:

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Broadcasting many batch dims:

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.
      With fewer than two inputs the converted inputs are returned as-is.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)

    # Work on a private list so the caller's iterable is never mutated.
    matrices = list(batch_matrices)
    for idx, raw in enumerate(matrices):
      matrices[idx] = ops.convert_to_tensor(raw)
      assert_is_batch_matrix(matrices[idx])

    # Nothing to broadcast against with zero or one matrix.
    if len(matrices) < 2:
      return matrices

    # First attempt: resolve the common batch shape statically.
    # E.g. for inputs x, y with
    #   x.shape =    [2, j, k]  (batch shape =    [2])
    #   y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # the common batch shape is [3, 2].
    static_batch = tensor_shape.TensorShape(matrices[0].shape)[:-2]
    for other in matrices[1:]:
      static_batch = _ops.broadcast_static_shape(
          static_batch,
          tensor_shape.TensorShape(other.shape)[:-2])

    if static_batch.is_fully_defined():
      for idx, mat in enumerate(matrices):
        # Only touch matrices whose batch shape actually differs.
        if tensor_shape.TensorShape(mat.shape)[:-2] != static_batch:
          target_shape = array_ops.concat(
              [static_batch.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          matrices[idx] = _ops.broadcast_to(mat, target_shape)
      return matrices

    # Static resolution failed; fall back to dynamic shapes.  Per the
    # original note, this path always copies data: every matrix is
    # broadcast, whether or not its batch shape already matches.
    dynamic_batch = array_ops.shape(matrices[0])[:-2]
    for other in matrices[1:]:
      dynamic_batch = array_ops.broadcast_dynamic_shape(
          dynamic_batch,
          array_ops.shape(other)[:-2])

    return [
        _ops.broadcast_to(
            mat,
            array_ops.concat(
                [dynamic_batch, array_ops.shape(mat)[-2:]], axis=0))
        for mat in matrices
    ]
예제 #9
0
 def _shape_tensor(self):
     """Compute the dynamic shape from the shapes of `self.row` and `self.col`.

     Returns:
       The broadcast shape of `self.row` and `self.col` with its own last
       entry appended once more, so the two trailing dimensions are equal.
     """
     row_shape = array_ops.shape(self.row)
     col_shape = array_ops.shape(self.col)
     vector_shape = array_ops.broadcast_dynamic_shape(row_shape, col_shape)
     # Repeat the trailing dimension to form the square matrix part.
     last_dim = vector_shape[-1]
     return array_ops.concat((vector_shape, [last_dim]), 0)