Code Example #1
File: jsense.py (Project: jkkronk/ddpPytorch)
def sense_estimation_ls(y, X, basis_funct, uspat):
    """
    We estimate the bias field with a polynomial basis of the given order, using least squares method
    :param data: data (Fourier) [n x m, c]
    :param x_estimate: predicted reconstruction estimate [n x m, c] (needed to compute recon_error!)
    :param max_basis_order:
    :param ls_threshold:
    :param max_bias_eval:
    :return:
    """
    num_coils, sizex, sizey = y.shape
    num_coeffs = basis_funct.shape[1]

    coeff_coils = torch.zeros((num_coils, num_coeffs),
                              dtype=torch.cfloat,
                              device=y.real.device)
    # XA - Y = 0
    for i in range(num_coils):
        Y = y[i, :, :].reshape(sizex * sizey)
        A = UFT(X, uspat,
                basis_funct[i, :, :, :]).reshape(num_coeffs, sizex * sizey)
        coeff = torch.matmul(
            torch.matmul(Y, torch.transpose(torch.conj(A), 0, 1)),
            complex_inverse(
                torch.matmul(A, torch.transpose(torch.conj(A), 0, 1))))
        coeff_coils[i, :] = coeff.clone()
        del Y
        del A
        del coeff

    return coeff_coils
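
For reference, a minimal self-contained sketch of the same per-coil normal-equation solve with illustrative sizes; the torch.linalg.lstsq call is an equivalent (and numerically steadier) alternative, not part of the original:

import torch

num_coeffs, npix = 4, 64
A = torch.randn(num_coeffs, npix, dtype=torch.cfloat)  # stands in for UFT(...)
Y = torch.randn(npix, dtype=torch.cfloat)              # one coil of y, flattened

# normal-equation form used above: coeff = Y A^H (A A^H)^{-1}
coeff = (Y @ A.conj().T) @ torch.linalg.inv(A @ A.conj().T)

# equivalent least-squares solve of min_c || A^T c - Y ||_2
coeff_ls = torch.linalg.lstsq(A.T, Y.unsqueeze(-1)).solution.squeeze(-1)
assert torch.allclose(coeff, coeff_ls, atol=1e-4)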
Code Example #2
    def xstep(self):
        r"""Minimise Augmented Lagrangian with respect to
      :math:`\mathbf{x}`."""

        self.YU[:] = self.Y - self.U

        b = self.DSf + self.rho * torch.fft.rfftn(self.YU, **self.fftopt)
        if self.cri.Cd == 1:
            self.Xf[:] = solvedbi_sm(self.Df, self.rho, b, self.c,
                                     self.cri.axisM)
        else:
            self.Xf[:] = solvemdbi_ism(self.Df, self.rho, b, self.cri.axisM,
                                       self.cri.axisC)

        self.X = torch.fft.irfftn(self.Xf, **self.fftopt)

        if self.opt['LinSolveCheck']:
            Dop = lambda x: torch.sum(self.Df * x, dim=self.cri.axisM)
            if self.cri.Cd == 1:
                DHop = lambda x: torch.conj(self.Df) * x
            else:
                DHop = lambda x: torch.sum(torch.conj(self.Df) * x,
                                           dim=self.cri.axisC)
            ax = DHop(Dop(self.Xf)) + self.rho * self.Xf
            self.xrrs = rrs(ax, b)
        else:
            self.xrrs = None
Code Example #3
    def forward(self, input, angle):
        
        # padding
        mag, ph, real, imag = self.stft.transform(input.reshape(-1, input.size()[-1]))
        pad = torch.zeros(mag.size(0), mag.size(1), 1).type(input.type())
        mag = torch.cat([mag, pad], -1)
        ph = torch.cat([ph, pad], -1)
        output, rest = self.pad_signal(input)
        enc_output = self.encoder(output[:, :1])  # B, N, L
        mag = mag.view(enc_output.size(0), self.n_mic, -1, enc_output.size(-1))
        ph = ph.view(enc_output.size(0), self.n_mic, -1, enc_output.size(-1))
        LPS = 10 * torch.log10(mag ** 2 + 10e-20)  # log power spectrum; epsilon avoids log(0)

        complex_spec = mag * torch.exp(ph * 1j)  # rebuild the complex spectrogram (avoids shadowing builtin `complex`)
        IPD_list = []
        for m in self.pairs:
            com_u1 = complex_spec[:, m[0]]
            com_u2 = complex_spec[:, m[1]]
            IPD = torch.angle(com_u1 * torch.conj(com_u2))
            IPD /= (self.frequency_vector + 1.0)[:, None]
            IPD = IPD % torch.pi
            IPD = IPD.unsqueeze(dim=1)
            IPD_list.append(IPD)
        IPD = torch.cat(IPD_list, dim=1)
        steering_vector = self.__get_steering_vector(angle, self.pairs)
        steering_vector = steering_vector.unsqueeze(dim=-1)
        AF = steering_vector * IPD
        AF = AF / AF.sum(dim=1, keepdim=True).real
        w = self.w.unsqueeze(dim=0).expand(AF.size()[0], -1, -1, -1)
        # directional power ratio: beamform each grid direction, then normalise
        # the per-direction power across the grid
        dpr = torch.einsum('bmgh,bmhj->bghj', w.to(complex_spec.dtype),
                           complex_spec)
        dpr = (dpr * torch.conj(dpr)) / torch.sum(
            dpr * torch.conj(dpr), dim=1, keepdim=True)

        feature_list = [enc_output.unsqueeze(dim=1), AF, dpr, torch.cos(IPD)]
        fusion = torch.cat(feature_list, dim=1).float()

        batch_size = output.size(0)
        fusion = fusion.view(batch_size, -1, fusion.size()[-1])
        
        # mask estimation


        masks = torch.sigmoid(self.TCN(fusion)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
        masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L
        
        # waveform decoder
        output = self.decoder(masked_output.view(batch_size*self.num_spk, self.enc_dim, -1))  # B*C, 1, L
        output = output[:,:,self.stride:-(rest+self.stride)].contiguous()  # B*C, 1, L
        output = output.view(batch_size, self.num_spk, -1)  # B, C, T

        return output
Code Example #4
def tikhonov_filter(s, *, lmbda=1.0, npd=16, dtype=torch.float32):
    r"""Lowpass filter based on Tikhonov regularization.

    Lowpass filter image(s) and return low and high frequency
    components, consisting of the lowpass filtered image and its
    difference with the input image. The lowpass filter is equivalent to
    Tikhonov regularization with `lmbda` as the regularization parameter
    and a discrete gradient as the operator in the regularization term,
    i.e. the lowpass component is the solution to

    .. math::
      \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s}
      \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;,

    where :math:`\mathbf{s}` is the input image, :math:`\lambda` is the
    regularization parameter, and :math:`G_i` is an operator that
    computes the discrete gradient along image axis :math:`i`. Once the
    lowpass component :math:`\mathbf{x}` has been computed, the highpass
    component is just :math:`\mathbf{s} - \mathbf{x}`.

    Parameters
    ----------
    s : array_like
      Input image or array of images.
    lmbda : float, optional (default=1.0)
      Regularization parameter controlling lowpass filtering.
    npd : int, optional (default=16)
      Number of samples to pad at image boundaries.
    dtype : torch.dtype, optional (default=torch.float32)
      Data type used for the internal gradient kernels.

    Returns
    -------
    slp : array_like
      Lowpass image or array of images.
    shp : array_like
      Highpass image or array of images.
    """

    # forward-difference kernels along rows and columns
    grv = torch.tensor([-1.0, 1.0], dtype=dtype, device=s.device).reshape(2, 1)
    gcv = torch.tensor([-1.0, 1.0], dtype=dtype, device=s.device).reshape(1, 2)
    fftopt = {"s": (s.shape[0] + 2 * npd, s.shape[1] + 2 * npd), "dim": (0, 1)}
    Gr = tfft.rfftn(grv, **fftopt)
    Gc = tfft.rfftn(gcv, **fftopt)
    A = 1.0 + lmbda * (torch.conj(Gr) * Gr + torch.conj(Gc) * Gc).real
    if s.ndim > 2:
        A = A[(slice(None), ) * 2 + (np.newaxis, ) * (s.ndim - 2)]
    fill = ((npd, npd), ) * 2 + ((0, 0), ) * (s.ndim - 2)
    # symmetric padding via NumPy, since torch.nn.functional.pad has no
    # 'symmetric' mode
    sp = torch.from_numpy(np.pad(s.cpu().numpy(), fill, 'symmetric')).to(s.device)
    spshp = sp.shape
    sp = tfft.rfftn(sp, dim=(0, 1))
    sp /= A
    sp = tfft.irfftn(sp, s=spshp[0:2], dim=(0, 1))
    slp = sp[npd:(sp.shape[0] - npd), npd:(sp.shape[1] - npd)]
    shp = s - slp
    return slp, shp
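
A hedged usage sketch for tikhonov_filter above (tfft is torch.fft, as in the function body); the two components sum back to the input by construction:

import numpy as np
import torch
import torch.fft as tfft

s = torch.rand(64, 64)
slp, shp = tikhonov_filter(s, lmbda=5.0)
assert slp.shape == s.shape and shp.shape == s.shape
assert torch.allclose(slp + shp, s, atol=1e-5)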
Code Example #5
File: gradient.py (Project: al5250/multicore-mri)
 def T_apply(self, y: Tensor) -> Tensor:
     _, _, H, W = y.size()
     if self.vertical:
         k = torch.arange(H, device=y.device).view(1, 1, -1, 1)
         x = torch.conj(1 - torch.exp(-2 * np.pi * 1j * k / H)) * y
     else:
         # broadcast k along the width axis (the vertical branch broadcasts along height)
         k = torch.arange(W, device=y.device).view(1, 1, 1, -1)
         x = torch.conj(1 - torch.exp(-2 * np.pi * 1j * k / W)) * y
     return x
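
The factor being conjugated is the frequency response of a circular forward difference; a small self-contained check of that identity (illustrative, not from the project):

import numpy as np
import torch

N = 16
x = torch.randn(N, dtype=torch.cfloat)
k = torch.arange(N)
# the circular forward difference x[n] - x[n-1] ...
diff = x - torch.roll(x, 1)
# ... equals multiplication by (1 - exp(-2j*pi*k/N)) in the Fourier domain;
# T_apply multiplies by the conjugate, i.e. applies the adjoint operator
diff_f = torch.fft.ifft((1 - torch.exp(-2j * np.pi * k / N)) * torch.fft.fft(x))
assert torch.allclose(diff, diff_f, atol=1e-5)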
Code Example #6
def tFT_pytorch(x, coilmaps):
    # inp: x, coilmaps [ns, nx, ny] (coil axis first)
    # out: [nx, ny]
    temp = torch.fft.ifftn(ifftshift(x, dim=(1, 2)), dim=(1, 2))

    # coil combination: conjugate-weighted sum normalised by coil power
    temp_scoil = torch.sum(temp * torch.conj(coilmaps), dim=0)
    temp_scoil = temp_scoil / torch.sum(coilmaps * torch.conj(coilmaps), dim=0)

    return temp_scoil
Code Example #7
File: medi.py (Project: Jinwei1209/Bayesian_QSM)
    def CG_body(self, i, rTr, x, r, p):
        # one conjugate-gradient iteration for the linear system AtA(x) = b
        Ap = self.AtA(p)
        alpha = rTr / torch.sum(torch.conj(p) * Ap)  # step size

        x = x + p * alpha  # update solution estimate
        r = r - Ap * alpha  # update residual
        rTrNew = torch.sum(torch.conj(r) * r)

        beta = rTrNew / rTr  # Fletcher-Reeves coefficient
        p = r + p * beta  # new conjugate search direction
        return i + 1, rTrNew, x, r, p
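
CG_body above is a single iteration shaped for a while-loop driver; a minimal self-contained loop with the same updates (a sketch, solving A x = b for a symmetric positive-definite A):

import torch

def cg_solve(AtA, b, num_iter=20):
    x = torch.zeros_like(b)
    r = b.clone()                      # residual, since x starts at zero
    p = r.clone()                      # initial search direction
    rTr = torch.sum(torch.conj(r) * r)
    for _ in range(num_iter):
        Ap = AtA(p)
        alpha = rTr / torch.sum(torch.conj(p) * Ap)
        x = x + p * alpha
        r = r - Ap * alpha
        rTrNew = torch.sum(torch.conj(r) * r)
        p = r + p * (rTrNew / rTr)
        rTr = rTrNew
    return x

A = torch.randn(8, 8)
A = A @ A.T + 8 * torch.eye(8)         # symmetric positive-definite
b = torch.randn(8)
x = cg_solve(lambda v: A @ v, b)
assert torch.allclose(A @ x, b, atol=1e-3)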
Code Example #8
def matexp(x, dt):
    """
    Closed-form matrix exponential of -1j * dt * [[0, conj(tau)], [tau, 0]].

    Since M = [[0, conj(tau)], [tau, 0]] satisfies M^2 = |tau|^2 * I,
    exp(-1j * dt * M) = cos(dt*|tau|) * I - 1j * sin(dt*|tau|) * M / |tau|.
    """
    exp = torch.zeros(x.shape, dtype=torch.cdouble)
    taus = x[:, 1, 0]
    exp[:, 0, 0] = torch.cos(dt * torch.abs(taus))
    exp[:, 0, 1] = -1j * torch.conj(taus) * torch.sin(dt * torch.abs(taus)) / torch.abs(taus)
    # |tau| / conj(tau) == tau / |tau|, so this is -1j * tau * sin(...) / |tau|
    exp[:, 1, 0] = -1j * torch.abs(taus) * torch.sin(dt * torch.abs(taus)) / torch.conj(taus)
    exp[:, 1, 1] = torch.cos(dt * torch.abs(taus))
    return exp
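
A quick sanity check (a sketch, not from the project) against torch.matrix_exp, which the closed form should match:

import torch

dt = 0.1
taus = torch.randn(5, dtype=torch.cdouble)
x = torch.zeros(5, 2, 2, dtype=torch.cdouble)
x[:, 0, 1] = torch.conj(taus)
x[:, 1, 0] = taus
assert torch.allclose(matexp(x, dt), torch.matrix_exp(-1j * dt * x), atol=1e-10)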
Code Example #9
File: model.py (Project: Zmjcc/DL_DV_pytorch)
    def SMR_loss(self, y_true, y_pred):
        Nt = self.Nt
        Nr = self.Nr
        dk = self.dk
        K = self.K
        p = self.p
        sigma_2 = self.sigma_2
        batch_size = y_true.shape[0]
        #H_noiseless = torch.view_as_complex(y_true[:,:(2*Nt*Nr*K)].reshape((-1,Nt,Nr,2,K)).permute(0,1,2,4,3).contiguous())
        H = torch.view_as_complex(
            y_true.reshape((-1, Nt, Nr, 2, K)).permute(0, 1, 2, 4,
                                                       3).contiguous())

        # p_list_pred = y_pred[:, :K * dk].type_as(H)
        # q_list_pred = y_pred[:, K * dk:2 * K * dk].type_as(H)
        # mrt_list_pred = y_pred[:, -1:].type_as(H)
        #restore V
        V = torch.view_as_complex(
            y_pred.reshape((-1, Nt, dk, K, 2)).contiguous())
        # normalize the precoding matrix
        V_flatten = V.reshape((-1, Nt * dk * K))
        energy_scale = torch.linalg.norm(V_flatten, dim=1).reshape(
            (-1, 1, 1, 1)).repeat(1, Nt, dk, K).type_as(H)
        V = V / energy_scale
        #V = self.DUU_EZF(H,p_list_pred,q_list_pred,mrt_list_pred)
        sum_rate = torch.zeros(1, device=y_true.device)
        for user in range(K):
            H_k = H[:, :, :, user].permute(0, 2, 1)
            V_k = V[:, :, :, user]
            signal_k = torch.matmul(H_k, V_k)
            signal_k_energy = torch.matmul(
                signal_k, torch.conj(signal_k.permute(0, 2, 1)))
            interference_k_energy = sigma_2 * torch.eye(Nr).type_as(H).reshape(
                (1, Nr, Nr)).repeat(batch_size, 1, 1)
            for j in range(K):
                if j != user:
                    V_j = V[:, :, :, j]
                    interference_j = torch.matmul(H_k, V_j)
                    interference_k_energy = interference_k_energy + torch.matmul(
                        interference_j,
                        torch.conj(interference_j.permute(0, 2, 1)))
            # accumulate this user's rate once all interference terms are in
            SINR_k = torch.matmul(signal_k_energy,
                                  torch.linalg.inv(interference_k_energy))
            rate_k = torch.log2(
                complex_det(SINR_k + torch.eye(Nr).type_as(H).reshape(
                    (1, Nr, Nr)).repeat(batch_size, 1, 1)))
            sum_rate = sum_rate + rate_k
        sum_rate = -sum_rate
        #self.minus_sum_rate_loss(H.detach().cpu().numpy(), V.detach().cpu().numpy())
        return torch.mean(sum_rate)
Code Example #10
    def setdict(self, D=None):
        """Set dictionary array."""
        # Change the dictionary and its Fourier transform
        if D is not None:
            # move the new dictionary onto the same device as the signal
            self.D = D.to(self.Sf.device, non_blocking=True)
            self.Df = torch.fft.rfftn(self.D, **self.tensoropt)

        # Compute D^H S
        self.DSf = torch.conj(self.Df) * self.Sf
        if self.cri.Cd > 1:
            self.DSf = torch.sum(self.DSf, dim=self.cri.axisC, keepdim=True)
        if self.opt['HighMemSolve'] and self.cri.Cd == 1:
            self.c = solvedbi_sm_c(self.Df, torch.conj(self.Df), self.rho,
                                   self.cri.axisM)
        else:
            self.c = None
Code Example #11
    def detect(self, img):
        p = self.pre_process(img)
        if self.features_extractor in [
                "resnet", "mobilenet", "vgg16", "alexnet"
        ]:
            inp = torch.from_numpy(p).unsqueeze(dim=0).float().to(self.device)
            features = self.model(inp)
            feature_maps = features.squeeze().detach()
            feature_maps_hann = self.pos_process(feature_maps)
            del inp
            del feature_maps
        else:
            raise ValueError(
                "unsupported feature extractor: " + str(self.features_extractor))

        self.X = torch.fft.fftn(feature_maps_hann)

        F = self.A / (self.B + self.lambda_)  # correlation filter: numerator / (denominator + regulariser)
        Y = self.X * torch.conj(F)
        self.g = torch.fft.ifftn(torch.sum(Y, dim=0))
        g_cpu = self.g.detach().cpu().numpy()
        loc = np.unravel_index(np.argmax(g_cpu), g_cpu.shape)
        rows = int(loc[0] * self.roi.height / self.X.shape[-2])
        cols = int(loc[1] * self.roi.width / self.X.shape[-1])

        self.bbox, self.roi = transf2ori(
            (rows, cols), self.bbox, self.roi,
            img.shape[1:])  #transform to the ori frame
Code Example #12
File: transform.py (Project: aisari/torchsar)
def ct2rt(x, axis=0):
    r"""Converts a complex-valued tensor to a real-valued tensor

    Converts a complex-valued tensor :math:`{\bf x}` to a real-valued tensor with FFT and conjugate symmetry.


    Parameters
    ----------
    x : Tensor
        The input tensor :math:`{\bf x}\in {\mathbb C}^{H×W}`.
    axis : int
        The axis along which to execute the FFT.

    Returns
    -------
    Tensor
        The output tensor :math:`{\bf y}\in {\mathbb R}^{2H×W}` ( :attr:`axis` = 0 ), :math:`{\bf y}\in {\mathbb R}^{H×2W}` ( :attr:`axis` = 1 )
    """

    d = x.dim()
    n = x.shape[axis]
    X = th.fft.fft(x, dim=axis)
    X0 = X[sl(d, axis, [[0]])]
    X1 = th.conj(X[sl(d, axis, range(n - 1, 0, -1))])
    # the Nyquist bin of the doubled spectrum stores the imaginary part of the
    # DC term (cast to the complex dtype), so no information is lost
    Y = th.cat((X, X0.imag.type_as(X), X1), dim=axis)
    Y[sl(d, axis, [[0]])] = X0.real + 0j
    del x, X, X1
    y = th.fft.ifft(Y, dim=axis)
    return y
Code Example #13
File: tensor_ops.py (Project: malfet/pytorch)
 def tensor_indexing_ops(self):
     x = torch.randn(2, 4)
     y = torch.randn(2, 4, 2)
     t = torch.tensor([[0, 0], [1, 0]])
     mask = x.ge(0.5)
     i = [0, 1]
     return (
         torch.cat((x, x, x), 0),
         torch.concat((x, x, x), 0),
         torch.conj(x),
         torch.chunk(x, 2),
         torch.dsplit(y, i),
         torch.column_stack((x, x)),
         torch.dstack((x, x)),
         torch.gather(x, 0, t),
         torch.hsplit(x, i),
         torch.hstack((x, x)),
         torch.index_select(x, 0, torch.tensor([0, 1])),
         torch.masked_select(x, mask),
         torch.movedim(x, 1, 0),
         torch.moveaxis(x, 1, 0),
         torch.narrow(x, 0, 0, 2),
         torch.nonzero(x),
         torch.permute(x, (0, 1)),
         torch.reshape(x, (-1, )),
     )
Code Example #14
def circular_correlation(
    a: torch.FloatTensor,
    b: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Compute the circular correlation between two vectors.

    .. note ::
        The implementation uses FFT.

    :param a: shape: s_1
        The tensor with the first vectors.
    :param b:
        The tensor with the second vectors.

    :return:
        The circular correlation between the vectors.
    """
    # Circular correlation of entity embeddings
    a_fft = rfft(a, dim=-1)
    b_fft = rfft(b, dim=-1)
    # complex conjugate
    a_fft = torch.conj(a_fft)
    # Hadamard product in frequency domain
    p_fft = a_fft * b_fft
    # inverse real FFT
    return irfft(p_fft, n=a.shape[-1], dim=-1)
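
A quick check (illustrative, assuming rfft/irfft are torch.fft.rfft/irfft in scope) that the FFT route matches the direct O(n^2) definition c[k] = sum_i a[i] * b[(i + k) mod n]:

import torch

a, b = torch.randn(8), torch.randn(8)
direct = torch.stack([(a * torch.roll(b, -k)).sum() for k in range(8)])
assert torch.allclose(circular_correlation(a, b), direct, atol=1e-5)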
Code Example #15
File: phase_proj.py (Project: jkkronk/ddpPytorch)
def _f_st(u, lmb, device):
    # group soft-thresholding of a complex field across dim 0
    uabs = torch.squeeze(torch.sqrt(torch.sum(u * torch.conj(u), dim=0).real))
    tmp = 1 - lmb / (uabs + 1e-8)
    tmp[tmp < 0] = 0  # magnitudes at or below lmb shrink to zero
    uu = u * tile(tmp.unsqueeze(0), 0, u.shape[0], device)
    return uu
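
For reference, a self-contained check of the shrinkage rule this implements, with torch.clamp in place of the masked assignment and broadcasting in place of the project's tile helper:

import torch

u = torch.randn(3, 16, dtype=torch.cfloat)
lmb = 0.5
mag = torch.sqrt(torch.sum(u * torch.conj(u), dim=0).real)
uu = u * torch.clamp(1 - lmb / (mag + 1e-8), min=0).unsqueeze(0)
new_mag = torch.sqrt(torch.sum(uu * torch.conj(uu), dim=0).real)
# group magnitudes shrink by lmb and clamp at zero
assert torch.allclose(new_mag, torch.clamp(mag - lmb, min=0), atol=1e-4)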
Code Example #16
File: modl.py (Project: utcsilab/deepinpy)
    def forward(self, x):

        assert self.x_adj is not None, "x_adj not computed!"
        r = self.denoiser(x)

        if self.A.single_channel:
            # multiply with maps because they might not be all-ones, and they include the fftmod term
            maps = self.A.maps.squeeze(1)
            r_ft = fft_forw(r * maps)
            x_ft_ones = (self.inp + self.l2lam * r_ft) / (1 + self.l2lam)
            x_ft = x_ft_ones * (abs(self.A.mask) != 0) + r_ft * (abs(
                self.A.mask) == 0)
            x = torch.conj(maps) * fft_adj(x_ft)
            self.num_cg = 0
        else:
            cg_op = ConjGrad(self.x_adj + self.l2lam * r,
                             self.A.normal,
                             l2lam=self.l2lam,
                             max_iter=self.hparams.cg_max_iter,
                             eps=self.hparams.cg_eps,
                             verbose=False)
            x = cg_op.forward(x)
            self.num_cg = cg_op.num_cg

        return x
Code Example #17
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim), )
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    check(last_dim_size >= 1,
          lambda: f"Invalid number of data points ({n}) specified")

    if n is not None:
        input = _resize_fft_input(input,
                                  dims=dims,
                                  sizes=(last_dim_size // 2 + 1, ))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output,
                       norm=norm,
                       signal_numel=last_dim_size,
                       forward=forward)
Code Example #18
    def forward(self, x, k, sf, sigma):
        '''
        x: tensor, NxCxWxH
        k: tensor, Nx(1,3)xwxh
        sf: integer, 1
        sigma: tensor, Nx1x1x1
        '''

        # initialization & pre-calculation
        w, h = x.shape[-2:]
        FB = p2o(k, (w * sf, h * sf))
        FBC = torch.conj(FB)
        F2B = torch.pow(torch.abs(FB), 2)
        STy = upsample(x, sf=sf)
        FBFy = FBC * torch.fft.fftn(STy, dim=(-2, -1))
        x = nn.functional.interpolate(x, scale_factor=sf, mode='nearest')

        # hyper-parameter, alpha & beta
        ab = self.h(
            torch.cat(
                (sigma, torch.tensor(sf).type_as(sigma).expand_as(sigma)),
                dim=1))

        # unfolding
        for i in range(self.n):

            x = self.d(x, FB, FBC, F2B, FBFy, ab[:, i:i + 1, ...], sf)
            x = self.p(
                torch.cat((x, ab[:, i + self.n:i + self.n + 1, ...].repeat(
                    1, 1, x.size(2), x.size(3))),
                          dim=1))

        return x
Code Example #19
def hole_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:  # noqa: D102
    """Evaluate the HolE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # Circular correlation of entity embeddings
    a_fft = rfft(h, dim=-1)
    b_fft = rfft(t, dim=-1)

    # complex conjugate
    a_fft = torch.conj(a_fft)

    # Hadamard product in frequency domain
    p_fft = a_fft * b_fft

    # inverse real FFT, shape: (b, h, 1, t, d)
    composite = irfft(p_fft, n=h.shape[-1], dim=-1)

    # transpose composite: (b, h, 1, d, t)
    composite = composite.transpose(-2, -1)

    # inner product with relation embedding
    return (r @ composite).squeeze(dim=-2)
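
A shape sketch with illustrative sizes (rfft/irfft are assumed to be torch.fft.rfft/irfft, as in circular_correlation above):

import torch

h = torch.randn(2, 3, 1, 1, 8)
r = torch.randn(2, 1, 5, 1, 8)
t = torch.randn(2, 1, 1, 7, 8)
assert hole_interaction(h, r, t).shape == (2, 3, 5, 7)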
Code Example #20
File: utils.py (Project: Helmholtz-AI-Energy/HyDe)
def _est_additive_noise(
    subdata: torch.Tensor, calculation_dtype: torch.dtype = torch.float
) -> Tuple[torch.Tensor, torch.Tensor]:
    # estimate the additive noise in the given data with a certain precision
    eps = 1e-6
    dim0data, dim1data = subdata.shape
    dtp = subdata.dtype
    subdata = subdata.to(dtype=calculation_dtype)
    w = torch.zeros(subdata.shape, dtype=calculation_dtype, device=subdata.device)
    ddp = subdata @ torch.conj(subdata).T
    # regularise the diagonal before inversion (MATLAB: RRi = inv(RR + eps*eye(p)))
    hld = ddp + eps * torch.eye(int(dim0data), dtype=calculation_dtype, device=subdata.device)
    ddpi = torch.inverse(hld)
    for i in range(dim0data):
        xx = ddpi - (torch.outer(ddpi[:, i], ddpi[i, :]) / ddpi[i, i])
        # XX = RRi - (RRi(:,i)*RRi(i,:))/RRi(i,i);
        ddpa = ddp[:, i].clone()  # clone to avoid mutating ddp through a view
        # RRa = RR(:,i);
        ddpa[i] = 0.0
        # RRa(i)=0; % this removes the effects of XX(:,i)
        beta = xx @ ddpa
        # beta = XX * RRa;
        beta[i] = 0
        # beta(i)=0; % this remove the effects of XX(i,:)
        w[i, :] = subdata[i, :] - (beta @ subdata)
    # Rw = diag(diag(w*w'/N));
    hold2 = torch.matmul(w, w.T) / float(subdata.shape[1])
    ret = torch.diag(torch.diagonal(hold2))
    w = w.to(dtype=dtp)
    ret = ret.to(dtype=dtp)
    return w, ret
Code Example #21
def rotate_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the RotatE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, 2*dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, 2*dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, 2*dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # r expresses a rotation in complex plane.
    h, r, t = [view_complex(x) for x in (h, r, t)]
    if estimate_cost_of_sequence(h.shape, r.shape) < estimate_cost_of_sequence(r.shape, t.shape):
        # rotate head by relation (=Hadamard product in complex space)
        h = h * r
    else:
        # rotate tail by inverse of relation
        # The inverse rotation is expressed by the complex conjugate of r.
        # The score is computed as the distance of the relation-rotated head to the tail.
        # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e.
        # |h * r - t| = |h - conj(r) * t|
        t = t * torch.conj(r)

    # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed
    return negative_norm(h - t, p=2, power_norm=False)
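
A small check (illustrative) of the identity quoted above; it holds because r has unit modulus, so multiplication by r preserves distances:

import torch

h = torch.randn(4, dtype=torch.cfloat)
t = torch.randn(4, dtype=torch.cfloat)
r = torch.polar(torch.ones(4), torch.rand(4) * 6.28)   # |r| = 1
assert torch.allclose(torch.abs(h * r - t),
                      torch.abs(h - torch.conj(r) * t), atol=1e-6)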
Code Example #22
def wiener_filter(img, psf, k):
    """Apply Wiener filter on images.

    Args:
        img: Tensor of image of shape `(N x C x H x W)` where N is the batch size, \
            C is the number of bands, H is height and W is width, containing the image data.
        psf: Tensor of shape `(N x C x H x W)`, representing the Point Spread Function.
        k: Tensor of shape `(N x 1)`, representing the Noise-to-Signal Ratio.

    Returns:
        Tensor of shape `(N x C x H x W)`. The deconvolved image data.
    """

    img2 = torch.clone(img)
    img_fft = torch.fft.fft2(img2)
    psf_fft = torch.fft.fft2(psf)
    batch_size, _, m, n = img2.shape

    laps = np.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]])
    m1 = (m - 3) // 2
    n1 = (n - 3) // 2
    laps = np.pad(laps, [[m1, m - m1 - 3], [n1, n - n1 - 3]])
    laps = torch.from_numpy(laps).to(img.device)  # match the input's device
    laps_fft = torch.fft.fft2(laps)

    k = k.reshape(batch_size, 1, 1, 1)
    f = torch.conj(psf_fft) / (torch.abs(psf_fft)**2 +
                               k * torch.abs(laps_fft)**2)
    restored_fft = f * img_fft  # avoid reusing `m`, which holds the image height
    return torch.fft.fftshift(torch.fft.ifft2(restored_fft).real)
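
A hedged usage sketch (assuming numpy is imported as np in the module): with a centered delta PSF the blur is, up to the half-period shift that the final fftshift undoes, an identity, so the output approximately reproduces the input for a small noise-to-signal ratio:

import torch

img = torch.rand(1, 1, 32, 32)
psf = torch.zeros(1, 1, 32, 32)
psf[..., 16, 16] = 1.0                 # centered delta point spread function
k = torch.tensor([[1e-6]])             # small noise-to-signal ratio
restored = wiener_filter(img, psf, k)
assert restored.shape == img.shape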
Code Example #23
File: feature_extractor.py (Project: SongJae/SELD)
def foa_intensity_vectors(complex_specs: torch.Tensor) -> torch.Tensor:
    if not torch.is_complex(complex_specs):
        complex_specs = torch.view_as_complex(complex_specs)

    # complex_specs: [chan, freq, time]
    IVx = torch.real(torch.conj(complex_specs[0]) * complex_specs[3])
    IVy = torch.real(torch.conj(complex_specs[0]) * complex_specs[1])
    IVz = torch.real(torch.conj(complex_specs[0]) * complex_specs[2])

    norm = torch.sqrt(IVx**2 + IVy**2 + IVz**2)
    IVx = IVx / norm
    IVy = IVy / norm
    IVz = IVz / norm

    # apply mel matrix without db ...
    return torch.stack([IVx, IVy, IVz], dim=0)
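
Usage sketch, assuming the ambisonic channel ordering [W, Y, Z, X] implied by the index pairs above:

import torch

spec = torch.randn(4, 64, 100, dtype=torch.cfloat)   # [chan, freq, time]
iv = foa_intensity_vectors(spec)
assert iv.shape == (3, 64, 100)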
Code Example #24
def compute_tke_spectrum_pytorch(u, v, w, lx, ly, lz, smooth):
    import torch.fft
    nx = len(u[:, 0, 0])
    ny = len(v[0, :, 0])
    nz = len(w[0, 0, :])

    nt = nx * ny * nz
    n = nx  #int(np.round(np.power(nt,1.0/3.0)))

    # 3-D transforms of the velocity components (fftn: the shell binning
    # below is over 3-D wavenumbers)
    uh = torch.fft.fftn(u) / nt
    vh = torch.fft.fftn(v) / nt
    wh = torch.fft.fftn(w) / nt

    # spectral kinetic-energy density
    tkeh = 0.5 * (uh * torch.conj(uh) + vh * torch.conj(vh) +
                  wh * torch.conj(wh)).real

    k0x = 2.0 * pi / lx
    k0y = 2.0 * pi / ly
    k0z = 2.0 * pi / lz

    knorm = (k0x + k0y + k0z) / 3.0

    kxmax = nx / 2
    kymax = ny / 2
    kzmax = nz / 2

    wave_numbers = knorm * torch.arange(0, n)

    tke_spectrum = torch.zeros([len(wave_numbers)])
    ks = get_ks(nx, ny, nz, kxmax, kymax, kzmax, str(u.device))
    for k in range(0, min(len(tke_spectrum), int(ks.max()))):
        tke_spectrum[k] = torch.sum(tkeh[ks == k]).item()

    tke_spectrum = tke_spectrum / knorm
    #  tke_spectrum = tke_spectrum[1:]
    #  wave_numbers = wave_numbers[1:]
    if smooth:
        tkespecsmooth = movingaverage(tke_spectrum, 5)  #smooth the spectrum
        tkespecsmooth[0:4] = tke_spectrum[
            0:4]  # get the first 4 values from the original data
        tke_spectrum = tkespecsmooth

    knyquist = knorm * min(nx, ny, nz) / 2

    return knyquist, wave_numbers, tke_spectrum
Code Example #25
File: compute_kernel.py (Project: wuxiaoxue/pykeen)
def _complex_native_complex(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Use torch built-ins for computation with complex numbers."""
    h, r, t = [view_complex(x=x) for x in (h, r, t)]
    return torch.real(tensor_product(h, r, torch.conj(t)).sum(dim=-1))
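
The same score written out without the project's view_complex/tensor_product helpers, for already-complex tensors (a sketch): the ComplEx score is Re(sum_i h_i * r_i * conj(t_i)):

import torch

h = torch.randn(5, 8, dtype=torch.cfloat)
r = torch.randn(5, 8, dtype=torch.cfloat)
t = torch.randn(5, 8, dtype=torch.cfloat)
score = torch.real((h * r * torch.conj(t)).sum(dim=-1))   # shape: (5,)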
Code Example #26
def conj(input_):
    """Wrapper of `torch.conj`.

    Parameters
    ----------
    input_ : DTensor
        Input tensor.

    Returns
    -------
    Tensor
        Elementwise complex conjugate of the wrapped data.
    """
    return torch.conj(input_._data)
Code Example #27
 def forward(self, h, r, t):
     h_e, r_e, t_e = self.embed(h, r, t)
     r_e = F.normalize(r_e, p=2, dim=-1)
     # circular correlation via the torch.fft module; the legacy
     # torch.fft(tensor, signal_ndim) API was removed in PyTorch 1.8, and
     # real inputs no longer need a stacked zero imaginary channel
     e = torch.fft.ifft(
         torch.conj(torch.fft.fft(h_e, dim=-1)) * torch.fft.fft(t_e, dim=-1),
         dim=-1).real
     return -torch.sigmoid(torch.sum(r_e * e, dim=-1))
Code Example #28
 def tensor_indexing_ops(self):
     x = torch.randn(2, 4)
     y = torch.randn(4, 4)
     t = torch.tensor([[0, 0], [1, 0]])
     mask = x.ge(0.5)
     i = [0, 1]
     return len((
         torch.cat((x, x, x), 0),
         torch.concat((x, x, x), 0),
         torch.conj(x),
         torch.chunk(x, 2),
         torch.dsplit(torch.randn(2, 2, 4), i),
         torch.column_stack((x, x)),
         torch.dstack((x, x)),
         torch.gather(x, 0, t),
         torch.hsplit(x, i),
         torch.hstack((x, x)),
         torch.index_select(x, 0, torch.tensor([0, 1])),
         x.index(t),
         torch.masked_select(x, mask),
         torch.movedim(x, 1, 0),
         torch.moveaxis(x, 1, 0),
         torch.narrow(x, 0, 0, 2),
         torch.nonzero(x),
         torch.permute(x, (0, 1)),
         torch.reshape(x, (-1, )),
         torch.row_stack((x, x)),
         torch.select(x, 0, 0),
         torch.scatter(x, 0, t, x),
         x.scatter(0, t, x.clone()),
         torch.diagonal_scatter(y, torch.ones(4)),
         torch.select_scatter(y, torch.ones(4), 0, 0),
         torch.slice_scatter(x, x),
         torch.scatter_add(x, 0, t, x),
         x.scatter_(0, t, y),
         x.scatter_add_(0, t, y),
         # torch.scatter_reduce(x, 0, t, reduce="sum"),
         torch.split(x, 1),
         torch.squeeze(x, 0),
         torch.stack([x, x]),
         torch.swapaxes(x, 0, 1),
         torch.swapdims(x, 0, 1),
         torch.t(x),
         torch.take(x, t),
         torch.take_along_dim(x, torch.argmax(x)),
         torch.tensor_split(x, 1),
         torch.tensor_split(x, [0, 1]),
         torch.tile(x, (2, 2)),
         torch.transpose(x, 0, 1),
         torch.unbind(x),
         torch.unsqueeze(x, -1),
         torch.vsplit(x, i),
         torch.vstack((x, x)),
         torch.where(x),
         torch.where(t > 0, t, 0),
         torch.where(t > 0, t, t),
     ))
Code Example #29
def conj(X):

    if th.is_complex(X):
        return th.conj(X)
    elif X.size(-1) == 2:
        # real/imaginary channel representation: negate the imaginary channel
        return th.stack((X[..., 0], -X[..., 1]), dim=-1)
    else:
        raise TypeError(
            'Unknown type! Only complex tensors and 2-channel real/imaginary '
            'representations are supported!')
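
A quick check (illustrative) that the 2-channel branch agrees with torch.conj on the equivalent complex tensor:

import torch as th

z = th.randn(4, 2)                     # real/imaginary channel representation
zc = th.view_as_complex(z)
assert th.allclose(th.view_as_complex(conj(z)), th.conj(zc))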
Code Example #30
    def decod_signal(self, signal, pulse_width, t, t_window):
        '''
        Takes symmetric pulses (negative and positive time) as input,
        but works only with pulses in POSITIVE time (without the pulse
        mirrored at zero).
        Parameters
        ----------
        signal : TYPE: torch.complex128 tensor of shape [batch_size, dim_z, dim_t].
            DESCRIPTION: Output of the split-step solution.
        pulse_width : TYPE: int
            DESCRIPTION: Pulse width.
        t : TYPE: torch.float32 tensor of shape [dim_t]
            DESCRIPTION: Time points. The boundaries of this vector are taken
            in such a way that the signal broadened as it propagates does not
            go beyond the calculation boundaries
        t_window : TYPE: torch.int64 tensor of shape [2] or (int, int)
            DESCRIPTION: Contains t_start and t_end to select the positive
            time range with pulses from t.
        Returns
        -------
        u : TYPE: torch.complex128 tensor of shape [batch_size, dim_z, dim_t_dec]
            DESCRIPTION: signal restricted to positive time
        u_shifted : TYPE: torch.complex128 tensor of shape [batch_size, dim_z, dim_t_dec]
            DESCRIPTION: copy of u shifted back by one pulse slot
        signal_decoded : TYPE: torch.float64 tensor of shape [batch_size, dim_z, dim_t_dec]
            DESCRIPTION: decoded signal power
        t_dec : TYPE: torch.float32 tensor of shape [dim_t_dec]
            DESCRIPTION: positive time points where pulses are present
        '''

        # save the device of the input signal
        device = signal.device

        # cut the time axis: we work only with positive time,
        # without the pulse mirrored at zero
        t = t.to(signal.device)
        T = pulse_width
        t_start, t_end = t_window
        t_start = torch.argmin(torch.abs(t - 0))  # snap the start to t = 0, overriding t_window's start
        t_dec = t[t_start:t_end]
        signal = signal[:, :, t_start:t_end]

        #preparation
        start_pulse = torch.argmin(torch.abs(t_dec - 0.5 * T))
        end_pulse = torch.argmin(torch.abs(t_dec - 1.5 * T))
        w_pulse = end_pulse - start_pulse

        # take the data without the pulse mirrored at zero
        u = torch.zeros_like(signal).to(device)
        u[:, :, start_pulse:t_end] = signal[:, :, start_pulse:t_end]

        u_shifted = torch.zeros_like(u).to(device)
        u_shifted[:, :, start_pulse:-w_pulse] = u[:, :, end_pulse:]

        # decoding: combine the signal with its shifted copy and take the power
        signal_decoded = (u + u_shifted)
        # signal_decoded = (u + u_shifted)/2
        signal_decoded = signal_decoded * torch.conj(signal_decoded)

        return u, u_shifted, signal_decoded.real, t_dec