def sense_estimation_ls(y, X, basis_funct, uspat):
    """Fit per-coil basis coefficients by linear least squares.

    For each coil ``i`` the system matrix
    ``A = UFT(X, uspat, basis_funct[i, :, :, :])`` is reshaped to
    ``(num_coeffs, sizex * sizey)`` and the coefficients solving
    ``coeff @ A ~= y[i]`` are obtained from the normal equations
    ``coeff = y[i] A^H (A A^H)^{-1}``.

    :param y: measured data, shape ``(num_coils, sizex, sizey)``.
    :param X: current image estimate, forwarded to ``UFT``.
    :param basis_funct: per-coil basis, indexed as ``basis_funct[i, :, :, :]``;
        its dim 1 is the number of coefficients.
    :param uspat: sampling pattern forwarded to ``UFT`` (presumably the
        undersampling mask -- confirm against ``UFT``).
    :return: complex64 tensor of shape ``(num_coils, num_coeffs)``.
    """
    n_coils, nx, ny = y.shape
    n_pix = nx * ny
    n_basis = basis_funct.shape[1]
    result = torch.zeros((n_coils, n_basis), dtype=torch.cfloat, device=y.real.device)

    for coil in range(n_coils):
        y_vec = y[coil, :, :].reshape(n_pix)
        A = UFT(X, uspat, basis_funct[coil, :, :, :]).reshape(n_basis, n_pix)
        A_h = torch.transpose(torch.conj(A), 0, 1)
        projected = torch.matmul(y_vec, A_h)
        gram_inv = complex_inverse(torch.matmul(A, A_h))
        result[coil, :] = torch.matmul(projected, gram_inv)

    return result
def xstep(self):
    r"""Minimise Augmented Lagrangian with respect to :math:`\mathbf{x}`.

    Forms the right-hand side ``b = self.DSf + rho * FFT(Y - U)`` and solves
    the frequency-domain system for ``self.Xf`` with ``solvedbi_sm`` (single
    channel dictionary, ``Cd == 1``) or ``solvemdbi_ism`` (multi-channel),
    then stores the spatial-domain solution in ``self.X``.

    When ``self.opt['LinSolveCheck']`` is set, the relative residual of the
    solve, comparing ``D^H D x + rho x`` with ``b``, is stored in
    ``self.xrrs``; otherwise ``self.xrrs`` is set to ``None``.
    """
    self.YU[:] = self.Y - self.U
    b = self.DSf + self.rho * torch.fft.rfftn(self.YU, **self.fftopt)
    if self.cri.Cd == 1:
        self.Xf[:] = solvedbi_sm(self.Df, self.rho, b, self.c, self.cri.axisM)
    else:
        self.Xf[:] = solvemdbi_ism(self.Df, self.rho, b, self.cri.axisM,
                                   self.cri.axisC)

    self.X = torch.fft.irfftn(self.Xf, **self.fftopt)

    if self.opt['LinSolveCheck']:
        # keepdim=True keeps each reduced axis as size 1 so the result still
        # broadcasts against Df / Xf; without it the axes misalign and the
        # residual check is computed on wrongly-shaped tensors (or fails).
        def Dop(x):
            return torch.sum(self.Df * x, dim=self.cri.axisM, keepdim=True)

        if self.cri.Cd == 1:
            def DHop(x):
                return torch.conj(self.Df) * x
        else:
            def DHop(x):
                return torch.sum(torch.conj(self.Df) * x, dim=self.cri.axisC,
                                 keepdim=True)

        ax = DHop(Dop(self.Xf)) + self.rho * self.Xf
        self.xrrs = rrs(ax, b)
    else:
        self.xrrs = None
def forward(self, input, angle):
    """Separate ``self.num_spk`` sources from a multichannel waveform.

    Spatial features are built from the STFT of every microphone:
    inter-channel phase differences (IPD), an angle feature ``AF`` that
    combines the steering vector for ``angle`` with the IPDs, and a
    normalised directional power ``dpr`` from the beamforming weights
    ``self.w``. These are concatenated with the encoder output of the first
    channel, ``self.TCN`` estimates one mask per speaker, and the masked
    encodings are decoded back to waveforms.

    Args:
        input: multichannel waveform; the last dim is time and all leading
            dims are flattened for the STFT. Presumably ``(B, n_mic, T)``;
            TODO confirm.
        angle: target direction passed to the steering-vector helper.

    Returns:
        Tensor of shape ``(B, num_spk, T_out)``.
    """
    # STFT of every channel. One zero frame is appended, presumably so the
    # frame count matches the encoder length used in the view() below.
    mag, ph, real, image = self.stft.transform(input.reshape(-1, input.size()[-1]))
    pad = torch.zeros(mag.size()[0], mag.size()[1], 1,
                      dtype=input.dtype, device=input.device)
    mag = torch.cat([mag, pad], -1)
    ph = torch.cat([ph, pad], -1)

    output, rest = self.pad_signal(input)
    enc_output = self.encoder(output[:, :1])  # B, N, L

    n_batch, n_frames = enc_output.size(0), enc_output.size(-1)
    mag = mag.view(n_batch, self.n_mic, -1, n_frames)
    ph = ph.view(n_batch, self.n_mic, -1, n_frames)
    spec = mag * torch.exp(ph * 1j)  # complex STFT: B, n_mic, F, frames

    # Inter-channel phase difference for each microphone pair.
    ipd_list = []
    for m in self.pairs:
        ipd = torch.angle(spec[:, m[0]] * torch.conj(spec[:, m[1]]))
        ipd /= (self.frequency_vector + 1.0)[:, None]
        ipd = ipd % torch.pi
        ipd_list.append(ipd.unsqueeze(dim=1))
    IPD = torch.cat(ipd_list, dim=1)

    steering_vector = self.__get_steering_vector(angle, self.pairs)
    steering_vector = steering_vector.unsqueeze(dim=-1)
    AF = steering_vector * IPD
    AF = AF / AF.sum(dim=1, keepdim=True).real

    # dpr[b, g, f, t] = sum_m w[m, g, f] * spec[b, m, f, t], computed for
    # every direction/bin/frame at once. This replaces a Python triple loop
    # with hard-coded bounds and keeps the result on the input's device.
    dpr = torch.einsum('mgf,bmft->bgft',
                       self.w.to(torch.complex128),
                       spec.to(torch.complex128))
    power = dpr * torch.conj(dpr)
    dpr = power / torch.sum(power, dim=1, keepdim=True)

    feature_list = [enc_output.unsqueeze(dim=1), AF, dpr, torch.cos(IPD)]
    fusion = torch.cat(feature_list, dim=1).float()
    batch_size = output.size(0)
    fusion = fusion.view(batch_size, -1, fusion.size()[-1])

    # waveform encoder
    masks = torch.sigmoid(self.TCN(fusion)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
    masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L

    # waveform decoder
    output = self.decoder(masked_output.view(batch_size * self.num_spk, self.enc_dim, -1))  # B*C, 1, L
    output = output[:, :, self.stride:-(rest + self.stride)].contiguous()  # B*C, 1, L
    output = output.view(batch_size, self.num_spk, -1)  # B, C, T
    return output
def tikhonov_filter(s, *, lmbda=1.0, npd=16, dtype=torch.float32):
    r"""Split image(s) into lowpass and highpass parts via Tikhonov regularization.

    The lowpass component :math:`\mathbf{x}` solves

    .. math::
       \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s}
       \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;,

    where :math:`\mathbf{s}` is the input image, :math:`\lambda` the
    regularization parameter and :math:`G_i` the discrete gradient along
    image axis :math:`i`. The solve is done in the Fourier domain over the
    first two axes after symmetric padding by ``npd`` samples; the highpass
    component is :math:`\mathbf{s} - \mathbf{x}`.

    Parameters
    ----------
    s : Tensor
        Input image or stack of images; filtering acts on axes 0 and 1.
    lmbda : float
        Regularization parameter controlling the amount of smoothing.
    npd : int, optional (default=16)
        Number of samples of symmetric padding at the image boundaries.
    dtype : torch.dtype, optional
        Currently unused.

    Returns
    -------
    slp : Tensor
        Lowpass image(s).
    shp : Tensor
        Highpass image(s), ``s - slp``.
    """
    dev = s.device
    kernel = np.array([-1.0, 1.0])
    grad_rows = torch.from_numpy(kernel.reshape([2, 1])).to(dev)
    grad_cols = torch.from_numpy(kernel.reshape([1, 2])).to(dev)

    padded_shape = (s.shape[0] + 2 * npd, s.shape[1] + 2 * npd)
    G_r = tfft.rfftn(grad_rows, s=padded_shape, dim=(0, 1))
    G_c = tfft.rfftn(grad_cols, s=padded_shape, dim=(0, 1))
    denom = 1.0 + lmbda * (torch.conj(G_r) * G_r + torch.conj(G_c) * G_c).real

    extra_dims = s.ndim - 2
    if extra_dims > 0:
        # Broadcast the 2-D frequency response over any trailing axes.
        denom = denom[(slice(None), ) * 2 + (None, ) * extra_dims]

    pad_width = ((npd, npd), ) * 2 + ((0, 0), ) * extra_dims
    padded = torch.from_numpy(np.pad(s.cpu().numpy(), pad_width, 'symmetric')).to(dev)
    spatial_shape = padded.shape[0:2]

    spectrum = tfft.rfftn(padded, dim=(0, 1))
    spectrum /= denom
    smooth = tfft.irfftn(spectrum, s=spatial_shape, dim=(0, 1))

    slp = smooth[npd:(smooth.shape[0] - npd), npd:(smooth.shape[1] - npd)]
    return slp, s - slp
def T_apply(self, y: Tensor) -> Tensor:
    """Scale ``y`` by ``conj(1 - exp(-2*pi*i*k/N))`` along one spatial axis.

    ``y`` is indexed as ``(N, C, H, W)``. When ``self.vertical`` is true, ``k``
    runs over the height axis (dim -2, ``N = H``). Otherwise it runs over the
    width axis (dim -1, ``N = W``). The factor is the conjugate of the DFT
    response of a first-order difference along that axis.

    Args:
        y: 4-D tensor ``(N, C, H, W)``; presumably a Fourier-domain image.

    Returns:
        Complex tensor with the same shape as ``y``.
    """
    _, _, H, W = y.size()
    if self.vertical:
        k = torch.arange(H, device=y.device).view(1, 1, -1, 1)
        n = H
    else:
        # k indexes the width axis, so it must vary along the last dim
        # (previously viewed as (1, 1, -1, 1), which broadcast along H).
        k = torch.arange(W, device=y.device).view(1, 1, 1, -1)
        n = W
    return torch.conj(1 - torch.exp(-2 * np.pi * 1j * k / n)) * y
def tFT_pytorch(x, coilmaps):
    """Inverse-FFT multi-coil k-space and combine the coils with their maps.

    Args:
        x: k-space with coils on dim 0 and the two spatial axes on dims 1
            and 2, i.e. ``[ns, nx, ny]``. It is ``ifftshift``-ed over dims
            (1, 2) before the inverse FFT.
        coilmaps: coil sensitivities broadcastable against the coil images,
            summed over dim 0.

    Returns:
        ``sum_c img_c * conj(S_c) / sum_c |S_c|^2``, shape ``[nx, ny]``.
    """
    coil_images = torch.fft.ifftn(ifftshift(x, dim=(1, 2)), dim=(1, 2))
    weighted = torch.sum(coil_images * torch.conj(coilmaps), axis=0)
    sensitivity = torch.sum(coilmaps * torch.conj(coilmaps), axis=0)
    return weighted / sensitivity
def CG_body(self, i, rTr, x, r, p):
    """Run one conjugate-gradient iteration for the operator ``self.AtA``.

    Args:
        i: iteration counter.
        rTr: current squared residual norm ``<r, r>``.
        x, r, p: current iterate, residual and search direction.

    Returns:
        ``(i + 1, rTr_new, x_new, r_new, p_new)``.
    """
    Ap = self.AtA(p)
    step = rTr / torch.sum(torch.conj(p) * Ap)
    x_next = x + p * step
    r_next = r - Ap * step
    rTr_next = torch.sum(torch.conj(r_next) * r_next)
    p_next = r_next + p * (rTr_next / rTr)
    return i + 1, rTr_next, x_next, r_next, p_next
def matexp(x, dt):
    """
    Calculates the matrix exponentiation for matrix of type -j * dt * [[0, tau*], [tau, 0]]

    Uses the closed form
    ``exp = cos(dt|tau|) I - j sin(dt|tau|)/|tau| * [[0, tau*], [tau, 0]]``,
    where ``tau = x[:, 1, 0]``. At ``tau == 0`` the ratio ``sin(dt|tau|)/|tau|``
    is replaced by its limit ``dt``, so the result is the identity instead of
    NaN.

    Args:
        x: batch of 2x2 matrices, indexed ``x[:, 1, 0]`` for tau.
        dt: step size (scalar or broadcastable against tau).

    Returns:
        complex128 tensor with the shape of ``x`` on ``x``'s device.
    """
    taus = x[:, 1, 0]
    mag = torch.abs(taus)
    theta = dt * mag
    cos_t = torch.cos(theta)
    nonzero = mag > 0
    # Guard the division so tau == 0 yields the analytic limit (dt), not 0/0.
    safe_mag = torch.where(nonzero, mag, torch.ones_like(mag))
    sin_over_mag = torch.where(nonzero, torch.sin(theta) / safe_mag,
                               dt * torch.ones_like(mag))

    exp = torch.zeros(x.shape, dtype=torch.cdouble, device=x.device)
    exp[:, 0, 0] = cos_t
    exp[:, 0, 1] = -1j * torch.conj(taus) * sin_over_mag
    exp[:, 1, 0] = -1j * taus * sin_over_mag
    exp[:, 1, 1] = cos_t
    return exp
def SMR_loss(self, y_true, y_pred):
    """Negative mean sum-rate of a multi-user MIMO precoder.

    Args:
        y_true: channel with real/imag parts packed so that it reshapes to
            ``(B, Nt, Nr, 2, K)``.
        y_pred: precoder with real/imag parts packed so that it reshapes to
            ``(B, Nt, dk, K, 2)``.

    Each sample's precoder is normalised to unit Frobenius norm. For every
    user ``k`` the rate is ``log2 det(I + S_k N_k^{-1})``, where ``S_k`` is the
    signal covariance and ``N_k`` is ``sigma_2 I`` plus the inter-user
    interference. The rates are summed over users, negated and averaged over
    the batch.

    Returns:
        Scalar tensor: ``-mean(sum_k rate_k)``.
    """
    Nt = self.Nt
    Nr = self.Nr
    dk = self.dk
    K = self.K
    sigma_2 = self.sigma_2
    batch_size = y_true.shape[0]

    # Channel H: (B, Nt, Nr, K) complex.
    H = torch.view_as_complex(
        y_true.reshape((-1, Nt, Nr, 2, K)).permute(0, 1, 2, 4, 3).contiguous())
    # Precoder V: (B, Nt, dk, K) complex.
    V = torch.view_as_complex(y_pred.reshape((-1, Nt, dk, K, 2)).contiguous())

    # Precoder normalisation: divide each sample by its Frobenius norm
    # (broadcast over Nt, dk, K).
    energy_scale = torch.linalg.norm(V.reshape((-1, Nt * dk * K)), dim=1).reshape(
        (-1, 1, 1, 1)).type_as(H)
    V = V / energy_scale

    eye = torch.eye(Nr).type_as(H).reshape((1, Nr, Nr)).repeat(batch_size, 1, 1)
    # Accumulate on the data's device instead of a hard-coded .cuda().
    sum_rate = torch.zeros(1, device=H.device)
    for user in range(K):
        H_k = H[:, :, :, user].permute(0, 2, 1)  # (B, Nr, Nt)
        V_k = V[:, :, :, user]  # (B, Nt, dk)
        signal_k = torch.matmul(H_k, V_k)
        signal_k_energy = torch.matmul(signal_k, torch.conj(signal_k.permute(0, 2, 1)))
        interference_k_energy = sigma_2 * eye
        for j in range(K):
            if j == user:
                continue
            interference_j = torch.matmul(H_k, V[:, :, :, j])
            interference_k_energy = interference_k_energy + torch.matmul(
                interference_j, torch.conj(interference_j.permute(0, 2, 1)))
        SINR_k = torch.matmul(signal_k_energy, torch.linalg.inv(interference_k_energy))
        rate_k = torch.log2(complex_det(SINR_k + eye))
        sum_rate = sum_rate + rate_k
    sum_rate = -sum_rate
    return torch.mean(sum_rate)
def setdict(self, D=None):
    """Set dictionary array.

    If ``D`` is given, it is stored as ``self.D`` (moved to the device of
    ``self.Sf``, which it is multiplied with below) and its transform
    ``self.Df`` is recomputed with ``self.tensoropt``. In every case
    ``self.DSf = conj(Df) * Sf`` is recomputed; it is summed over the channel
    axis (kept as size 1) when ``Cd > 1``. When ``opt['HighMemSolve']`` is set
    and ``Cd == 1``, the cached solver term ``self.c`` is rebuilt; otherwise it
    is ``None``.
    """
    # `if D:` is ambiguous for multi-element tensors; test for None instead.
    if D is not None:
        # Tensor.device is an attribute, not a method; use .to() to move D.
        self.D = D.to(self.Sf.device, non_blocking=True)
        self.Df = torch.fft.rfftn(self.D, **self.tensoropt)
    # Compute D^H S
    self.DSf = torch.conj(self.Df) * self.Sf
    if self.cri.Cd > 1:
        self.DSf = torch.sum(self.DSf, dim=self.cri.axisC, keepdim=True)
    if self.opt['HighMemSolve'] and self.cri.Cd == 1:
        self.c = solvedbi_sm_c(self.Df, torch.conj(self.Df), self.rho,
                               self.cri.axisM)
    else:
        self.c = None
def detect(self, img):
    """Locate the target in ``img`` with the learned correlation filter.

    CNN features of the pre-processed image are correlated in the Fourier
    domain with the filter ``self.A / (self.B + self.lambda_)``. The response
    is summed over dim 0, and the peak of its real part is rescaled from
    feature-map to ROI coordinates. ``self.bbox`` and ``self.roi`` are then
    updated via ``transf2ori``.

    Side effects: sets ``self.X``, ``self.g``, ``self.bbox`` and ``self.roi``.

    Raises:
        ValueError: if ``self.features_extractor`` is not a supported backbone
            (previously this surfaced as an UnboundLocalError).
    """
    p = self.pre_process(img)
    if self.features_extractor not in ["resnet", "mobilenet", "vgg16", "alexnet"]:
        raise ValueError(
            f"unsupported features_extractor: {self.features_extractor!r}")

    inp = torch.from_numpy(p).unsqueeze(dim=0).float().to(self.device)
    features = self.model(inp)
    feature_maps = features.squeeze().detach()
    feature_maps_hann = self.pos_process(feature_maps)
    del inp
    del feature_maps

    self.X = torch.fft.fftn(feature_maps_hann)
    # Regularised filter: lambda belongs in the denominator (A / (B + lambda)).
    # `A / B + lambda` added a constant to the filter itself.
    F = self.A / (self.B + self.lambda_)
    Y = self.X * torch.conj(F)
    self.g = torch.fft.ifftn(torch.sum(Y, dim=0))
    # Peak of the real response; numpy argmax on complex arrays orders
    # lexicographically (real, then imag).
    g_cpu = self.g.detach().cpu().real.numpy()
    loc = np.unravel_index(np.argmax(g_cpu), g_cpu.shape)
    rows = int(loc[0] * self.roi.height / self.X.shape[-2])
    cols = int(loc[1] * self.roi.width / self.X.shape[-1])
    self.bbox, self.roi = transf2ori(
        (rows, cols), self.bbox, self.roi, img.shape[1:])  # transform to the ori frame
def ct2rt(x, axis=0):
    r"""Converts a complex-valued tensor to a real-valued tensor.

    The FFT of :math:`{\bf x}` along ``axis`` (length ``n``) is extended to
    length ``2n`` with a conjugate-symmetric layout:

    - index 0 holds the real part of the DC term;
    - indices ``1 .. n-1`` hold the original spectrum;
    - index ``n`` holds the imaginary part of the DC term;
    - the remaining indices hold the conjugated spectrum in reverse order.

    The inverse FFT of this spectrum is then returned.

    Parameters
    ----------
    x : Tensor
        The input tensor :math:`{\bf x}\in {\mathbb C}^{H×W}`.
    axis : int
        The axis along which the FFT is executed.

    Returns
    -------
    Tensor
        Tensor of length ``2n`` along ``axis``
        (:math:`2H×W` for ``axis=0``, :math:`H×2W` for ``axis=1``). Note that
        the value returned is the (complex-dtype) output of ``ifft``.
    """
    ndim = x.dim()
    length = x.shape[axis]
    spec = th.fft.fft(x, axis=axis)
    dc = spec[sl(ndim, axis, [[0]])]
    mirrored = th.conj(spec[sl(ndim, axis, range(length - 1, 0, -1))])
    extended = th.cat((spec, dc.imag, mirrored), dim=axis)
    extended[sl(ndim, axis, [[0]])] = dc.real + 0j
    del x, spec, mirrored
    return th.fft.ifft(extended, axis=axis)
def tensor_indexing_ops(self):
    """Return the results of a set of indexing/joining/splitting tensor ops.

    Two random inputs (a 2x4 matrix and a 2x4x2 volume) plus a fixed index
    tensor feed each op once; the outputs are returned as one tuple.
    """
    base = torch.randn(2, 4)
    volume = torch.randn(2, 4, 2)
    idx = torch.tensor([[0, 0], [1, 0]])
    positive = base.ge(0.5)
    split_points = [0, 1]
    return (
        torch.cat((base, base, base), 0),
        torch.concat((base, base, base), 0),
        torch.conj(base),
        torch.chunk(base, 2),
        torch.dsplit(volume, split_points),
        torch.column_stack((base, base)),
        torch.dstack((base, base)),
        torch.gather(base, 0, idx),
        torch.hsplit(base, split_points),
        torch.hstack((base, base)),
        torch.index_select(base, 0, torch.tensor([0, 1])),
        torch.masked_select(base, positive),
        torch.movedim(base, 1, 0),
        torch.moveaxis(base, 1, 0),
        torch.narrow(base, 0, 0, 2),
        torch.nonzero(base),
        torch.permute(base, (0, 1)),
        torch.reshape(base, (-1,)),
    )
def circular_correlation(
    a: torch.FloatTensor,
    b: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Compute the circular correlation between two vectors.

    ``out[..., k] = sum_n a[..., n] * b[..., (n + k) mod d]``, where ``d`` is the
    size of the last dimension of ``a``.

    .. note ::
        The implementation uses FFT: ``irfft(conj(rfft(a)) * rfft(b))``.

    :param a: shape: s_1
        The tensor with the first vectors.
    :param b:
        The tensor with the second vectors.

    :return:
        The circular correlation between the vectors.
    """
    # Correlation = product with the complex conjugate in frequency space.
    spectrum = torch.conj(rfft(a, dim=-1)) * rfft(b, dim=-1)
    return irfft(spectrum, n=a.shape[-1], dim=-1)
def _f_st(u, lmb, device): # soft thresholding uabs = torch.squeeze(torch.sqrt(torch.sum(u * torch.conj(u), dim=0))) tmp = 1 - lmb / (uabs + 1e-8) tmp[torch.abs(tmp) < 0] = 0 uu = u * tile(tmp.unsqueeze(0), 0, u.shape[0], device) return uu
def forward(self, x):
    """Data-consistency step applied after the denoiser ``self.denoiser``.

    Single-channel operators (``self.A.single_channel``) are handled in
    closed form in k-space:
    - sampled locations (``|mask| != 0``) take
      ``(self.inp + l2lam * F(r * maps)) / (1 + l2lam)``;
    - unsampled locations keep ``F(r * maps)``;
    - the result is mapped back with ``conj(maps) * fft_adj(.)``.

    Otherwise ``ConjGrad`` is run on ``self.A.normal`` with right-hand side
    ``x_adj + l2lam * r`` and warm start ``x``.

    Side effect: sets ``self.num_cg`` (0 in the closed-form branch).
    """
    assert self.x_adj is not None, "x_adj not computed!"
    r = self.denoiser(x)

    if not self.A.single_channel:
        solver = ConjGrad(self.x_adj + self.l2lam * r,
                          self.A.normal,
                          l2lam=self.l2lam,
                          max_iter=self.hparams.cg_max_iter,
                          eps=self.hparams.cg_eps,
                          verbose=False)
        x = solver.forward(x)
        self.num_cg = solver.num_cg
        return x

    # Multiply with the maps: they may not be all-ones and include the fftmod term.
    maps = self.A.maps.squeeze(1)
    r_ft = fft_forw(r * maps)
    blended = (self.inp + self.l2lam * r_ft) / (1 + self.l2lam)
    sampled = abs(self.A.mask) != 0
    x_ft = blended * sampled + r_ft * sampled.logical_not()
    x = torch.conj(maps) * fft_adj(x_ft)
    self.num_cg = 0
    return x
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)

    The output length along ``dim`` is ``n`` if given, else
    ``2 * (input.shape[dim] - 1)``. When ``n`` is given, the input is resized to
    ``n // 2 + 1`` along ``dim``. For the forward (hfft) direction the input is
    conjugated first. Normalisation follows ``norm``, using the output length
    as the signal size.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim), )
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    # Report the computed size: when n is None the old message printed "None".
    check(last_dim_size >= 1,
          lambda: f"Invalid number of data points ({last_dim_size}) specified")

    if n is not None:
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1, ))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
def forward(self, x, k, sf, sigma):
    '''
    Unfolded restoration network.

    x: tensor, NxCxWxH
    k: tensor, Nx(1,3)xwxh
    sf: integer, 1
    sigma: tensor, Nx1x1x1

    Frequency-domain quantities of the kernel are precomputed once:
    - ``FB``: the OTF from ``p2o``;
    - ``FBC``: its conjugate;
    - ``F2B``: its squared magnitude;
    - ``FBFy``: ``conj(FB)`` times the FFT of ``x`` upsampled with ``upsample``.

    ``self.h`` maps ``(sigma, sf)`` to per-stage hyper-parameters. For each of
    the ``self.n`` stages, ``self.d`` is applied with the stage's first
    hyper-parameter, then ``self.p`` with the second one concatenated as an
    extra channel.
    '''
    w, h = x.shape[-2:]
    FB = p2o(k, (w * sf, h * sf))
    FBC = torch.conj(FB)
    F2B = torch.pow(torch.abs(FB), 2)
    STy = upsample(x, sf=sf)
    FBFy = FBC * torch.fft.fftn(STy, dim=(-2, -1))
    x = nn.functional.interpolate(x, scale_factor=sf, mode='nearest')

    sf_map = torch.tensor(sf).type_as(sigma).expand_as(sigma)
    ab = self.h(torch.cat((sigma, sf_map), dim=1))

    n_stages = self.n
    for stage in range(n_stages):
        alpha = ab[:, stage:stage + 1, ...]
        x = self.d(x, FB, FBC, F2B, FBFy, alpha, sf)
        beta = ab[:, stage + n_stages:stage + n_stages + 1, ...]
        beta_map = beta.repeat(1, 1, x.size(2), x.size(3))
        x = self.p(torch.cat((x, beta_map), dim=1))

    return x
def hole_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:  # noqa: D102
    """Evaluate the HolE interaction function.

    The score is the inner product of the relation with the circular
    correlation of head and tail, computed via FFT.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    dim = h.shape[-1]
    # Circular correlation h * t: conj(F(h)) . F(t), inverted; shape (b, h, 1, t, d).
    correlated = irfft(torch.conj(rfft(h, dim=-1)) * rfft(t, dim=-1), n=dim, dim=-1)
    # (b, h, 1, d, t) so that r @ correlated contracts over d.
    correlated = correlated.transpose(-2, -1)
    scores = r @ correlated
    return scores.squeeze(dim=-2)
def _est_additive_noise( subdata: torch.Tensor, calculation_dtype: torch.dtype = torch.float ) -> Tuple[torch.Tensor, torch.Tensor]: # estimate the additive noise in the given data with a certain precision eps = 1e-6 dim0data, dim1data = subdata.shape dtp = subdata.dtype subdata = subdata.to(dtype=calculation_dtype) w = torch.zeros(subdata.shape, dtype=calculation_dtype, device=subdata.device) ddp = subdata @ torch.conj(subdata).T hld = (ddp + eps) @ torch.eye(int(dim0data), dtype=calculation_dtype, device=subdata.device) ddpi = torch.inverse(hld) for i in range(dim0data): xx = ddpi - (torch.outer(ddpi[:, i], ddpi[i, :]) / ddpi[i, i]) # XX = RRi - (RRi(:,i)*RRi(i,:))/RRi(i,i); ddpa = ddp[:, i] # RRa = RR(:,i); ddpa[i] = 0.0 # RRa(i)=0; % this remove the effects of XX(:,i) beta = xx @ ddpa # beta = XX * RRa; beta[i] = 0 # beta(i)=0; % this remove the effects of XX(i,:) w[i, :] = subdata[i, :] - (beta @ subdata) # ret = torch.diag(torch.diag(ddp / dim1data)) # Rw=diag(diag(w*w'/N)); # print("here", w.shape) hold2 = torch.matmul(w, w.T) / float(subdata.shape[1]) ret = torch.diag(torch.diagonal(hold2)) w = w.to(dtype=dtp) ret = ret.to(dtype=dtp) return w, ret
def rotate_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the RotatE interaction function.

    The relation acts as an element-wise rotation in the complex plane. The
    score is the negative L2 distance ``-|h * r - t|``. Since
    ``|h * r - t| = |h - conj(r) * t|``, the rotation is applied to whichever
    side ``estimate_cost_of_sequence`` predicts is cheaper.

    :param h: shape: (batch_size, num_heads, 1, 1, 2*dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, 2*dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, 2*dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    h, r, t = map(view_complex, (h, r, t))
    head_side_cheaper = (
        estimate_cost_of_sequence(h.shape, r.shape)
        < estimate_cost_of_sequence(r.shape, t.shape)
    )
    if head_side_cheaper:
        # Rotate the head by the relation (Hadamard product in complex space).
        h = h * r
    else:
        # Rotate the tail by the inverse relation (complex conjugate of r).
        t = t * torch.conj(r)
    # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed
    return negative_norm(h - t, p=2, power_norm=False)
def wiener_filter(img, psf, k):
    """Apply Wiener filter on images.

    Uses a Laplacian regulariser:
    ``F = conj(PSF) / (|PSF|^2 + k * |LAP|^2)``. The result is
    ``fftshift(real(ifft2(F * IMG)))``, where the shift acts only on the two
    spatial axes.

    Args:
        img: Tensor of image of shape `(N x C x H x W)` where N is the batch_size, \
            C is the number of band, H is height and W is width, containing the image data.
        psf: Tensor of shape `(N x C x H x W)`, representing the Point Spread Function.
        k: Tensor of shape `(N x 1)`, representing the Noise-to-Signal Ratio.

    Returns:
        Tensor of shape `(N x C x H x W)`. The deconvolved image data.
    """
    img_fft = torch.fft.fft2(img)
    psf_fft = torch.fft.fft2(psf)
    batch_size, _, height, width = img.shape

    # 3x3 discrete Laplacian zero-padded to the image size.
    laps = np.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]])
    m1 = (height - 3) // 2
    n1 = (width - 3) // 2
    laps = np.pad(laps, [[m1, height - m1 - 3], [n1, width - n1 - 3]])
    # Put the kernel on the image's device so GPU inputs work.
    laps_fft = torch.fft.fft2(torch.from_numpy(laps).to(img.device))

    k = k.reshape(batch_size, 1, 1, 1)
    f = torch.conj(psf_fft) / (torch.abs(psf_fft)**2 + k * torch.abs(laps_fft)**2)
    restored = torch.fft.ifft2(f * img_fft).real
    # Shift only the spatial axes; shifting all dims also rolled batch/channel.
    return torch.fft.fftshift(restored, dim=(-2, -1))
def foa_intensity_vectors(complex_specs: torch.Tensor) -> torch.Tensor:
    """Unit-norm active intensity vectors from a 4-channel spectrogram.

    Each component is ``Re(conj(ch0) * ch_j)``: x uses channel 3, y channel 1,
    z channel 2. The (x, y, z) vector is normalised to unit length per
    time-frequency bin. Bins where all three components are zero map to 0
    instead of NaN.

    Args:
        complex_specs: complex ``[chan, freq, time]`` tensor, or a real tensor
            with a trailing real/imag dim of size 2.

    Returns:
        Real tensor ``[3, freq, time]``.
    """
    if not torch.is_complex(complex_specs):
        complex_specs = torch.view_as_complex(complex_specs)
    # complex_specs: [chan, freq, time]
    ref = torch.conj(complex_specs[0])
    IVx = torch.real(ref * complex_specs[3])
    IVy = torch.real(ref * complex_specs[1])
    IVz = torch.real(ref * complex_specs[2])

    norm = torch.sqrt(IVx**2 + IVy**2 + IVz**2)
    # Silent bins have norm 0; clamping to the smallest normal float turns
    # 0/0 = NaN into 0 and leaves all other bins unchanged.
    norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)

    # apply mel matrix without db ...
    return torch.stack([IVx / norm, IVy / norm, IVz / norm], dim=0)
def compute_tke_spectrum_pytorch(u, v, w, lx, ly, lz, smooth):
    """Shell-averaged turbulent kinetic energy spectrum of a 3-D velocity field.

    ``tkeh = 0.5 * (|U|^2 + |V|^2 + |W|^2)`` is formed from the normalised 3-D
    FFTs of the velocity components. It is then summed over the wavenumber
    shells ``ks == k`` returned by ``get_ks`` and divided by the mean
    fundamental wavenumber ``knorm``.

    Args:
        u, v, w: velocity components on an ``(nx, ny, nz)`` grid.
        lx, ly, lz: domain lengths.
        smooth: if true, apply a 5-point moving average, keeping the first
            four raw values.

    Returns:
        ``(knyquist, wave_numbers, tke_spectrum)``.
    """
    import torch.fft

    nx = u.shape[0]
    ny = v.shape[1]
    nz = w.shape[2]
    nt = nx * ny * nz
    n = nx  # int(np.round(np.power(nt,1.0/3.0)))

    # The spectrum needs the full 3-D transform; torch.fft.fft only
    # transformed the last axis.
    uh = torch.fft.fftn(u) / nt
    vh = torch.fft.fftn(v) / nt
    wh = torch.fft.fftn(w) / nt
    tkeh = 0.5 * (uh * torch.conj(uh) + vh * torch.conj(vh) + wh * torch.conj(wh)).real

    k0x = 2.0 * pi / lx
    k0y = 2.0 * pi / ly
    k0z = 2.0 * pi / lz
    knorm = (k0x + k0y + k0z) / 3.0

    kxmax = nx / 2
    kymax = ny / 2
    kzmax = nz / 2

    wave_numbers = knorm * torch.arange(0, n)
    tke_spectrum = torch.zeros([len(wave_numbers)])

    # Build the shell indices on the same device as tkeh (was hard-coded "cuda:0").
    ks = get_ks(nx, ny, nz, kxmax, kymax, kzmax, tkeh.device)
    # +1 so the outermost shell (k == ks.max()) is included.
    num_shells = min(len(tke_spectrum), int(ks.max()) + 1)
    for k in range(num_shells):
        tke_spectrum[k] = torch.sum(tkeh[ks == k]).item()

    tke_spectrum = tke_spectrum / knorm

    if smooth:
        tkespecsmooth = movingaverage(tke_spectrum, 5)  # smooth the spectrum
        tkespecsmooth[0:4] = tke_spectrum[0:4]  # get the first 4 values from the original data
        tke_spectrum = tkespecsmooth

    knyquist = knorm * min(nx, ny, nz) / 2

    return knyquist, wave_numbers, tke_spectrum
def _complex_native_complex(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Use torch built-ins for computation with complex numbers.

    Views all three inputs as complex and returns ``Re(sum(h * r * conj(t)))``
    over the last dim, where the triple product is taken by ``tensor_product``.
    """
    h_c, r_c, t_c = (view_complex(x=x) for x in (h, r, t))
    trilinear = tensor_product(h_c, r_c, torch.conj(t_c)).sum(dim=-1)
    return torch.real(trilinear)
def conj(input_):
    """Wrapper of `torch.conj`.

    Parameters
    ----------
    input_ : DTensor
        Input tensor; its underlying ``_data`` tensor is conjugated.

    Returns
    -------
    Tensor
        ``torch.conj`` of ``input_._data``.
    """
    data = input_._data
    return torch.conj(data)
def forward(self, h, r, t):
    """Score (head, relation, tail) triples with HolE.

    The score is ``-sigmoid(<r_hat, h * t>)``, where ``r_hat`` is the
    L2-normalised relation embedding and ``h * t`` is the circular
    correlation of the head and tail embeddings along the last dim, computed
    via FFT.

    Returns:
        Tensor of scores, summed over dim 1.
    """
    h_e, r_e, t_e = self.embed(h, r, t)
    r_e = F.normalize(r_e, p=2, dim=-1)
    # Circular correlation via the torch.fft module. The removed
    # torch.fft(x, 1)/torch.ifft(x, 1) functions took real/imag pairs instead.
    e = torch.fft.ifft(
        torch.conj(torch.fft.fft(h_e, dim=-1)) * torch.fft.fft(t_e, dim=-1),
        dim=-1).real
    # torch.sigmoid replaces the deprecated F.sigmoid.
    return -torch.sigmoid(torch.sum(r_e * e, 1))
def tensor_indexing_ops(self):
    """Evaluate a long list of indexing, joining, splitting and scatter ops.

    Every op is an argument of a single ``len(...)`` call. Arguments are
    evaluated left to right, so the in-place ``x.scatter_`` /
    ``x.scatter_add_`` calls near the end mutate ``x`` before the later
    arguments read it.

    NOTE(review): in eager Python, ``len`` accepts exactly one argument, and
    ``x.index(t)`` does not appear to be a public ``torch.Tensor`` method. This
    is presumably an op-coverage fixture meant for TorchScript compilation
    rather than eager execution -- confirm before relying on its result.
    """
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    return len(
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x.index(t),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1, )),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        # In-place ops: every argument after these sees the mutated x.
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    )
def conj(X):
    """Complex conjugate for native complex tensors or real/imag pairs.

    Native complex tensors go through ``th.conj``. A real tensor whose last
    dim has size 2 is treated as ``(..., [real, imag])``, and the imaginary
    part is negated.

    Raises:
        TypeError: for real tensors whose last dim is not 2.
    """
    if not th.is_complex(X):
        if X.size(-1) != 2:
            raise TypeError(
                'Not known type! Only real and imag representions are supported!')
        return th.stack((X[..., 0], -X[..., 1]), dim=-1)
    return th.conj(X)
def decod_signal(self, signal, pulse_width, t, t_window):
    '''
    Decode a split-step output by adding a copy of it shifted by one pulse width.

    Takes as input symmetric pulses (negative and positive time) but works only
    with pulses in POSITIVE time (without the symmetric pulse at zero).

    Parameters
    ----------
    signal : torch.complex128 tensor of shape [batch_size, dim_z, dim_t]
        Output of the split-step solution.
    pulse_width : int
        Pulse width.
    t : torch.float32 tensor of shape [dim_t]
        Time points. The boundaries of this vector are chosen so that the
        signal, broadened as it propagates, stays inside the calculation window.
    t_window : torch.int64 tensor of shape [2] or (int, int)
        ``(t_start, t_end)`` index window. Only ``t_end`` is used; the start
        index is always the sample nearest to ``t = 0``.

    Returns
    -------
    u : complex tensor [batch_size, dim_z, dim_t_dec]
        Windowed signal with samples before the first pulse (0.5*T) zeroed.
    u_shifted : complex tensor [batch_size, dim_z, dim_t_dec]
        ``u`` advanced by one pulse width, zero-filled at the end.
    signal_decoded : real tensor [batch_size, dim_z, dim_t_dec]
        ``|u + u_shifted|**2``.
    t_dec : torch.float32 tensor of shape [dim_t_dec]
        Positive time points covered by the pulses.
    '''
    device = signal.device

    # Cut the time axis: keep only positive time (no symmetric pulse at zero).
    t = t.to(device)
    T = pulse_width
    t_start, t_end = t_window
    t_start = torch.argmin(torch.abs(t - 0))
    t_dec = t[t_start:t_end]
    signal = signal[:, :, t_start:t_end]

    # Sample indices of the first pulse's start/end and its width in samples.
    start_pulse = torch.argmin(torch.abs(t_dec - 0.5 * T))
    end_pulse = torch.argmin(torch.abs(t_dec - 1.5 * T))
    w_pulse = int(end_pulse - start_pulse)

    # Take data without the symmetric pulse at zero.
    u = torch.zeros_like(signal)
    u[:, :, start_pulse:t_end] = signal[:, :, start_pulse:t_end]
    u_shifted = torch.zeros_like(u)
    # Explicit stop index: `start:-w_pulse` is an empty slice when w_pulse == 0.
    stop = u.shape[-1] - w_pulse
    u_shifted[:, :, start_pulse:stop] = u[:, :, end_pulse:]

    # Decoding
    signal_decoded = u + u_shifted
    signal_decoded = signal_decoded * torch.conj(signal_decoded)

    return u, u_shifted, signal_decoded.real, t_dec