Ejemplo n.º 1
0
  def _unblockify_then_matricize(self, vec):
    """Collapse the trailing block dims of `vec`, then rotate into a batch matrix.

    Example with `self.block_depth = 2` and `vec` of shape `[v0, v1, v2, v3]`:

      1. Un-blockify: the last `block_depth` dims are merged into one,
         giving `[v0, v1, v2*v3]`.
      2. Matricize: the first axis is rotated to the end, giving
         `[v1, v2*v3, v0]`, i.e. a `[v1]` batch of `[v2*v3, v0]` matrices.

    Args:
      vec: Tensor whose trailing `self.block_depth` dimensions form the block.

    Returns:
      The reshaped and axis-rotated tensor described above.
    """
    depth = self.block_depth
    static_shape = tensor_shape.TensorShape(vec.shape)

    if not static_shape.is_fully_defined():
      # Dynamic path: assemble the target shape as a 1-D int tensor.
      dyn_shape = array_ops.shape(vec)
      collapsed = [math_ops.reduce_prod(dyn_shape[-depth:])]
      target = array_ops.concat((dyn_shape[:-depth], collapsed), 0)
    else:
      # Static path: all dims are known, so build the target shape in Python.
      dims = static_shape.as_list()
      target = dims[:-depth] + [np.prod(dims[-depth:])]

    flattened = array_ops.reshape(vec, target)
    # shift=-1 moves the leading axis to the end: [v0, v1, v2*v3] -> [v1, v2*v3, v0].
    return distribution_util.rotate_transpose(flattened, shift=-1)
Ejemplo n.º 2
0
 def _shape_tensor(self):
   """Return the dynamic shape of the operator as a 1-D tensor.

   The shape of `self._spectrum` is split into leading (batch) dims and the
   trailing `self.block_depth` dims. The trailing dims are multiplied together
   to get `n`, and the result is `batch_shape + [n, n]`.
   """
   depth = self.block_depth
   spectrum_shape = array_ops.shape(self._spectrum)
   leading = spectrum_shape[:-depth]
   size = math_ops.reduce_prod(spectrum_shape[-depth:])
   return array_ops.concat((leading, [size, size]), 0)
Ejemplo n.º 3
0
 def _shape_tensor(self, spectrum=None):
     """Return the shape `batch_shape + [n, n]` implied by a spectrum.

     Args:
       spectrum: Optional spectrum to derive the shape from. When `None`,
         `self.spectrum` is used.

     Returns:
       The leading (batch) dims of the spectrum's shape followed by `[n, n]`,
       where `n` is the product of the trailing `self.block_depth` dims.
     """
     if spectrum is None:
         spectrum = self.spectrum
     depth = self.block_depth
     spectrum_shape = prefer_static.shape(spectrum)
     leading = spectrum_shape[:-depth]
     size = math_ops.reduce_prod(spectrum_shape[-depth:])
     return prefer_static.concat((leading, [size, size]), 0)
Ejemplo n.º 4
0
 def _determinant(self):
     """Compute the determinant by multiplying spectrum values.

     The product is taken over the trailing `self.block_depth` axes of
     `self.spectrum` (axes -1, -2, ...). The result is cast to `self.dtype`.
     """
     # range(-1, -depth - 1, -1) yields [-1, -2, ..., -depth].
     trailing_axes = list(range(-1, -self.block_depth - 1, -1))
     product = math_ops.reduce_prod(self.spectrum, axis=trailing_axes)
     # NOTE(review): if the spectrum is complex and self.dtype is real, this
     # cast drops the imaginary part — confirm that is intended.
     return _ops.cast(product, self.dtype)
 def _determinant(self):
   """Return the product of the diagonal entries along the last axis.

   The diagonal comes from `self._get_diag()`, and the product over its final
   axis is returned.
   """
   diagonal = self._get_diag()
   return math_ops.reduce_prod(diagonal, axis=[-1])