Exemplo n.º 1
0
  def _trace(self):
    """Returns the trace of the [[nested] block] circulant operator.

    Every diagonal entry of a [[nested] block] circulant operator equals the
    mean of the spectrum: for the [0, ..., 0] entry this follows from the
    inverse DFT formula, and all diagonal entries are identical.  Summing the
    N equal diagonal entries therefore gives the sum of the spectrum over its
    trailing `block_depth` (block-index) dimensions.

    Example shapes: if `spectrum.shape = [B1, ..., Bb, N1, N2]` then
    `self.shape = [B1, ..., Bb, N, N]` with `N1 * N2 = N`, and the returned
    trace has shape `[B1, ..., Bb]`.

    Returns:
      Tensor of dtype `self.dtype`: the real sum of the spectrum for real
      dtypes, or the complex sum for complex dtypes.
    """
    # Reduce over the trailing `block_depth` dimensions.  Negative axes work
    # whether or not the spectrum's rank/shape is statically known, so no
    # static/dynamic branching on `is_fully_defined()` is required.
    axis = [-(i + 1) for i in range(self.block_depth)]

    # Real part of the diagonal sum, "re_d".
    re_d_value = math_ops.reduce_sum(math_ops.real(self.spectrum), axis=axis)

    if not np.issubdtype(self.dtype, np.complexfloating):
      return _ops.cast(re_d_value, self.dtype)

    # Imaginary part, "im_d".  A self-adjoint operator has a real diagonal, so
    # the reduction over the imaginary spectrum can be skipped.
    if self.is_self_adjoint:
      im_d_value = array_ops.zeros_like(re_d_value)
    else:
      im_d_value = math_ops.reduce_sum(math_ops.imag(self.spectrum), axis=axis)

    return _ops.cast(math_ops.complex(re_d_value, im_d_value), self.dtype)
Exemplo n.º 2
0
 def _log_abs_determinant(self):
     """Default log|det| for every batch member, computed from a dense matrix.

     Logs a warning, because this path materializes `self.to_dense()` and
     costs O(N^3).

     - When `self._can_use_cholesky()` is true, it uses a Cholesky
       factorization: log|det| = 2 * sum(log(diag(chol(A)))) over the last
       axis.
     - Otherwise it returns the log-abs-determinant produced by
       `linalg.slogdet`.
     """
     # `warning` is the supported spelling; `warn` is a deprecated alias.
     logging.warning(
         "Using (possibly slow) default implementation of determinant."
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     if self._can_use_cholesky():
         # For A = L L^H: det(A) = prod(diag(L))^2, hence the factor of 2.
         diag = _linalg.diag_part(linalg_ops.cholesky(self.to_dense()))
         return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
     _, log_abs_det = linalg.slogdet(self.to_dense())
     return log_abs_det
Exemplo n.º 3
0
 def _diag_part(self):
     """Returns the main diagonal of `base_operator + U D V^H`.

     The code multiplies by `conj(self.v)`, so the update is `U D V^H`:

       [U D V^H]_{ii} = sum_{jk} U_{ij} D_{jk} conj(V_{ik})
                      = sum_{j}  U_{ij} D_{jj} conj(V_{ij})   (D diagonal)

     When `self.diag_update` is None, no D factor is applied (D acts as the
     identity).
     """
     # Elementwise U_{ij} * conj(V_{ij}).
     product = self.u * math_ops.conj(self.v)
     if self.diag_update is not None:
         # Insert a row axis so D_{jj} broadcasts across every row i.
         product = product * array_ops.expand_dims(self.diag_update,
                                                   axis=-2)
     # Sum over j, then add the base operator's own diagonal.
     return (math_ops.reduce_sum(product, axis=-1) +
             self.base_operator.diag_part())
Exemplo n.º 4
0
    def _log_abs_determinant(self):
        """Log |det(L + U D V^H)| via the matrix determinant lemma.

        det(L + U D V^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
                         = det(C) det(D) det(L),
        so the result is the sum of the three log-abs-determinants, with C the
        capacitance matrix (`self._capacitance`).
        """
        log_abs_det_d = self.diag_operator.log_abs_determinant()
        log_abs_det_l = self.base_operator.log_abs_determinant()

        if not self._use_cholesky:
            # NOTE(review): forming det(C) before taking the log can
            # overflow/underflow for large C; a slogdet-style routine would be
            # more robust — confirm one is available in this backend.
            abs_det_c = math_ops.abs(
                linalg_ops.matrix_determinant(self._capacitance))
            log_abs_det_c = math_ops.log(abs_det_c)
            if np.issubdtype(self.dtype, np.complexfloating):
                # abs() produced a real tensor; restore the operator's dtype.
                log_abs_det_c = _ops.cast(log_abs_det_c, dtype=self.dtype)
        else:
            # log|det(C)| = 2 * sum(log(diag(chol(C)))).
            log_diag = math_ops.log(_linalg.diag_part(self._chol_capacitance))
            log_abs_det_c = 2 * math_ops.reduce_sum(log_diag, axis=[-1])

        return log_abs_det_c + log_abs_det_d + log_abs_det_l
Exemplo n.º 5
0
 def _trace(self):
     """Trace: the sum of `self.diag_part()` along its last axis."""
     diagonal_entries = self.diag_part()
     return math_ops.reduce_sum(diagonal_entries, axis=-1)
Exemplo n.º 6
0
 def _log_abs_determinant(self):
     """Sum of log|spectrum| over the trailing `block_depth` dimensions.

     The reduced result is cast to `self.dtype` before being returned.
     """
     # Axes -1, -2, ..., -block_depth.
     block_axes = list(range(-1, -self.block_depth - 1, -1))
     log_abs_spectrum = math_ops.log(math_ops.abs(self.spectrum))
     total = math_ops.reduce_sum(log_abs_spectrum, axis=block_axes)
     return _ops.cast(total, self.dtype)
 def _log_abs_determinant(self):
   """Sum of log|d_i| over the last axis of the diagonal from `_get_diag()`.

   NOTE(review): `math_ops.abs` yields a real tensor and the result is not
   cast back to `self.dtype`, so for complex operators this returns a real
   dtype — confirm callers accept that.
   """
   log_abs_diag = math_ops.log(math_ops.abs(self._get_diag()))
   return math_ops.reduce_sum(log_abs_diag, axis=[-1])
Exemplo n.º 8
0
 def _log_abs_determinant(self):
     """Sum of log|d_i| over the last axis of `self._diag`.

     `abs()` makes the sum real-valued, so for complex dtypes the result is
     cast back to `self.dtype`; for real dtypes it is returned unchanged.
     """
     log_abs_diag = math_ops.log(math_ops.abs(self._diag))
     log_det = math_ops.reduce_sum(log_abs_diag, axis=[-1])
     is_complex = np.issubdtype(self.dtype, np.complexfloating)
     return _ops.cast(log_det, dtype=self.dtype) if is_complex else log_det