Пример #1
0
 def __init__(self,
              diag,
              dims=None,
              dir=0,
              device='cpu',
              togpu=(False, False),
              tocpu=(False, False),
              dtype=torch.float32):
     """Build the operator state from the diagonal ``diag``.

     Parameters
     ----------
     diag : torch.Tensor, ComplexTensor or numpy array-like
         Diagonal values. Anything that is not a ``torch.Tensor`` or
         ``ComplexTensor`` is treated as a NumPy array (it must provide
         ``flatten()``) and converted to a tensor.
     dims : sequence of int, optional
         When given, the flattened diagonal is reshaped so that it has
         ``dims[dir]`` entries along axis ``dir`` and size 1 elsewhere,
         and ``shape`` becomes ``(prod(dims), prod(dims))``.
     dir : int, optional
         Axis of ``dims`` along which the diagonal is laid out.
     device, togpu, tocpu, dtype
         Stored unchanged on the instance; not otherwise used here.
     """
     # Normalise the input into a flat (complex or real) torch container.
     if isinstance(diag, (torch.Tensor, ComplexTensor)):
         self.complex = isinstance(diag, ComplexTensor)
         if self.complex:
             self.diag = flatten(diag)
         else:
             self.diag = diag.flatten()
     else:
         self.complex = bool(np.iscomplexobj(diag))
         flat_diag = diag.flatten()
         if self.complex:
             self.diag = complextorch_fromnumpy(flat_diag)
         else:
             self.diag = torch.from_numpy(flat_diag)

     if dims is not None:
         # Singleton axes everywhere except ``dir`` so that the product
         # with an input reshaped to ``dims`` broadcasts along ``dir``.
         bcast_shape = [1] * len(dims)
         bcast_shape[dir] = dims[dir]
         if self.complex:
             self.diag = reshape(self.diag, bcast_shape)
         else:
             self.diag = self.diag.reshape(bcast_shape)
         nelem = np.prod(dims)
         self.shape = (nelem, nelem)
         self.dims = dims
         self.reshape = True
     else:
         ndiag = len(self.diag)
         self.shape = (ndiag, ndiag)
         self.dims = None
         self.reshape = False

     self.device = device
     self.togpu = togpu
     self.tocpu = tocpu
     self.dtype = dtype
     self.explicit = False
     self.Op = None
Пример #2
0
 def _matvec(self, x):
     """Apply the operator: elementwise product of ``self.diag`` and ``x``.

     Parameters
     ----------
     x : torch.Tensor or ComplexTensor
         Input vector.

     Returns
     -------
     y
         ``self.diag * x``. When ``self.reshape`` is set, ``x`` is first
         reshaped to ``self.dims`` and the product is flattened again.
     """
     if not self.reshape:
         return self.diag * x
     if self.complex:
         # Complex data goes through the module-level reshape/flatten
         # helpers rather than the tensor methods.
         return flatten(self.diag * reshape(x, self.dims))
     return (self.diag * x.reshape(self.dims)).view(-1)
Пример #3
0
 def _rmatvec(self, x):
     """Apply the adjoint: elementwise product with the adjoint diagonal.

     For complex operators the diagonal is passed through ``conj``
     first; for real operators it is used as is.

     Parameters
     ----------
     x : torch.Tensor or ComplexTensor
         Input vector.

     Returns
     -------
     y
         Product of the adjoint diagonal and ``x``. When ``self.reshape``
         is set, ``x`` is first reshaped to ``self.dims`` and the product
         is flattened again.
     """
     adj_diag = conj(self.diag) if self.complex else self.diag
     if not self.reshape:
         return adj_diag * x
     if self.complex:
         return flatten(adj_diag * reshape(x, self.dims))
     return (adj_diag * x.reshape(self.dims)).view(-1)
Пример #4
0
 def _matvec(self, x):
     """Apply the operator: product of the stored operand ``self.A`` with ``x``.

     Parameters
     ----------
     x : torch.Tensor or ComplexTensor
         Input vector.

     Returns
     -------
     y
         Result of the product. When ``self.reshape`` is set, ``x`` is
         first reshaped to ``self.newshape[0]`` and the product is
         flattened to one dimension.
     """
     if self.complex:
         if self.reshape:
             return flatten(self.A.mm(reshape(x, self.newshape[0])))
         # Without reshaping, the complex input is transposed before the
         # product and the result transposed back.
         # NOTE(review): presumably required by the ComplexTensor layout
         # expected by ``mm`` - confirm.
         return self.A.mm(x.t()).t()
     if self.reshape:
         x_nd = torch.reshape(x, self.newshape[0])
         return self.A.matmul(x_nd).view(-1)
     return self.A.matmul(x)