示例#1
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Return the all-zeros product of this operator with `x`.

        Only the shape and dtype of `x` are read; no multiplication is done.

        Args:
          x: Matrix-shaped tensor argument (at least two dimensions).
          adjoint: If true, `self._num_columns` (instead of
            `self._num_rows`) is used as the row count of the result in the
            non-square case. Ignored in the square case.
          adjoint_arg: If true, `x` is treated as adjointed, i.e. its last
            two dimensions are swapped when computing the output shape.

        Returns:
          Zeros of dtype `x.dtype`, passed through
          `self._possibly_broadcast_batch_shape`.
        """
        if self._assert_proper_shapes:
            # Validate the *effective* argument, but keep `x` itself
            # un-adjointed: the shape logic below applies `adjoint_arg` on
            # its own, so rebinding `x` to its adjoint here would swap the
            # last two dimensions twice and produce a transposed output
            # shape whenever assertions are enabled.
            effective_x = linalg.adjoint(x) if adjoint_arg else x
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, effective_x)
            x = distribution_util.with_dependencies([aps], x)

        x_shape = prefer_static.shape(x)

        if self.is_square:
            # Note that adjoint has no effect since this matrix is self-adjoint.
            if adjoint_arg:
                # Output has the shape of adjoint(x): swap the last two dims.
                output_shape = prefer_static.concat(
                    [x_shape[:-2], [x_shape[-1], x_shape[-2]]], axis=0)
            else:
                output_shape = x_shape

            return self._possibly_broadcast_batch_shape(
                array_ops.zeros(shape=output_shape, dtype=x.dtype))

        # Non-square case: rows come from the operator, columns from the
        # (possibly adjointed) argument.
        n = self._num_columns if adjoint else self._num_rows
        m = x_shape[-2] if adjoint_arg else x_shape[-1]

        output_shape = prefer_static.concat([x_shape[:-2], [n, m]], axis=0)

        zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
        return self._possibly_broadcast_batch_shape(zeros)
示例#2
0
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Multiply `x` elementwise by this operator's multiplier matrix.

     Args:
       x: Tensor argument.
       adjoint: Forwarded as `conjugate` to `self._make_multiplier_matrix`.
       adjoint_arg: If true, `x` is replaced by `linalg.adjoint(x)` first.

     Returns:
       The (possibly adjointed) `x` times the multiplier matrix.
     """
     if adjoint_arg:
         x = linalg.adjoint(x)

     if self._assert_proper_shapes:
         # Tie the dimension-compatibility assertion to `x` so it is
         # evaluated before the product is computed.
         shape_check = (
             linear_operator_util.assert_compatible_matrix_dimensions(
                 self, x))
         x = distribution_util.with_dependencies([shape_check], x)

     multiplier = self._make_multiplier_matrix(conjugate=adjoint)
     return x * multiplier
示例#3
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Divide `rhs` elementwise by this operator's multiplier matrix.

     Args:
       rhs: Tensor right-hand side.
       adjoint: Forwarded as `conjugate` to `self._make_multiplier_matrix`.
       adjoint_arg: If true, `rhs` is replaced by `linalg.adjoint(rhs)`
         first.

     Returns:
       The (possibly adjointed) `rhs` divided by the multiplier matrix.
     """
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)

     if self._assert_proper_shapes:
         # Tie the dimension-compatibility assertion to `rhs` so it is
         # evaluated before the quotient is computed.
         shape_check = (
             linear_operator_util.assert_compatible_matrix_dimensions(
                 self, rhs))
         rhs = distribution_util.with_dependencies([shape_check], rhs)

     divisor = self._make_multiplier_matrix(conjugate=adjoint)
     return rhs / divisor
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Return `x` (adjointed if requested), broadcast over the batch shape.

   Args:
     x: Tensor argument.
     adjoint: Unused; this matrix is self-adjoint.
     adjoint_arg: If true, `x` is replaced by `linalg.adjoint(x)` first.

   Returns:
     The result of `self._possibly_broadcast_batch_shape` applied to the
     (possibly adjointed) `x`.
   """
   if adjoint_arg:
     x = linalg.adjoint(x)

   if self._assert_proper_shapes:
     # Tie the dimension-compatibility assertion to `x` so it is evaluated
     # before the result is used.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         self, x)
     x = distribution_util.with_dependencies([shape_check], x)

   return self._possibly_broadcast_batch_shape(x)