def _single_tensor_adadelta(params: List[Tensor],
                            grads: List[Tensor],
                            square_avgs: List[Tensor],
                            acc_deltas: List[Tensor],
                            *,
                            lr: float,
                            rho: float,
                            eps: float,
                            weight_decay: float,
                            maximize: bool):
    for (param, grad, square_avg, acc_delta) in zip(params, grads,
                                                    square_avgs, acc_deltas):
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            square_avg = torch.view_as_real(square_avg)
            acc_delta = torch.view_as_real(acc_delta)
            grad = torch.view_as_real(grad)

        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        std = square_avg.add(eps).sqrt_()
        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)

        if torch.is_complex(param):
            delta = torch.view_as_complex(delta)
        param.add_(delta, alpha=-lr)
def matmulcc(A, B):
    """Complex matrix multiplication like A * B in matlab

    Parameters
    ----------
    A : {torch array}
        any size torch array; both complex and real representations are
        supported. For the real representation, the last dimension is 2
        (the first element is the real part, the second the imaginary part).
    B : {torch array}
        any size torch array; both complex and real representations are
        supported. For the real representation, the last dimension is 2
        (the first element is the real part, the second the imaginary part).

    Returns
    -------
    torch array
        result of complex matrix multiplication with the same representation
        as :attr:`A`.
    """
    if th.is_complex(A) and th.is_complex(B):
        return th.matmul(A.real, B.real) - th.matmul(A.imag, B.imag) + 1j * (
            th.matmul(A.real, B.imag) + th.matmul(A.imag, B.real))
    else:
        return th.stack(
            (th.matmul(A[..., 0], B[..., 0]) - th.matmul(A[..., 1], B[..., 1]),
             th.matmul(A[..., 0], B[..., 1]) + th.matmul(A[..., 1], B[..., 0])),
            dim=-1)
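# Usage sketch (illustrative, not from the original source): the complex
# branch of matmulcc should agree with torch.matmul on native complex
# tensors, and the (..., 2) real representation round-trips through
# torch.view_as_real / torch.view_as_complex.
import torch as th

A = th.randn(3, 4, dtype=th.complex64)
B = th.randn(4, 5, dtype=th.complex64)
assert th.allclose(matmulcc(A, B), th.matmul(A, B), atol=1e-5)

# Real representation: the trailing dimension of size 2 holds (real, imag).
out = matmulcc(th.view_as_real(A), th.view_as_real(B))  # shape (3, 5, 2)
assert th.allclose(th.view_as_complex(out), th.matmul(A, B), atol=1e-5)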
def einsum(equation, *operands):
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
    # mixed input with complex and real tensors.
    if len(operands) == 1:
        if isinstance(operands[0], (tuple, list)):
            operands = operands[0]
        complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)
    elif len(operands) != 2:
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if same_type:
            _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
            return _einsum(equation, *operands)
        else:
            raise ValueError("0 or More than 2 operands are not supported.")

    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.einsum(equation, a, b)
    else:
        return torch.einsum(equation, a, b)
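# Usage sketch (illustrative; assumes is_torch_1_9_plus is True in the host
# module): mixing one real and one complex torch tensor is dispatched to two
# real-valued einsum calls over the real and imaginary parts.
import torch

a = torch.randn(2, 3, dtype=torch.complex64)
b = torch.randn(3, 4)  # real operand
out = einsum("ij,jk->ik", a, b)
assert torch.allclose(out, torch.matmul(a, b.to(a.dtype)), atol=1e-5)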
def adadelta(params: List[Tensor],
             grads: List[Tensor],
             square_avgs: List[Tensor],
             acc_deltas: List[Tensor],
             *,
             lr: float,
             rho: float,
             eps: float,
             weight_decay: float):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.
    """
    for (param, grad, square_avg, acc_delta) in zip(params, grads,
                                                    square_avgs, acc_deltas):
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            square_avg = torch.view_as_real(square_avg)
            acc_delta = torch.view_as_real(acc_delta)
            grad = torch.view_as_real(grad)

        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        std = square_avg.add(eps).sqrt_()
        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)

        if torch.is_complex(param):
            delta = torch.view_as_complex(delta)
        param.add_(delta, alpha=-lr)
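# Minimal single-step sketch of the functional API above (hypothetical
# values; the state buffers start at zero, exactly as torch.optim.Adadelta
# initializes them).
import torch

param = torch.randn(3)
grad = torch.randn(3)
square_avg = torch.zeros(3)
acc_delta = torch.zeros(3)
adadelta([param], [grad], [square_avg], [acc_delta],
         lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)
# param, square_avg and acc_delta are all updated in place.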
def inner_product(val1: Tensor, val2: Tensor, dim: int = -1) -> Tensor:
    """Complex inner product.

    Args:
        val1: A tensor for the inner product.
        val2: A second tensor for the inner product.
        dim: An integer indicating the complex dimension (for real inputs
            only).

    Returns:
        The complex inner product of ``val1`` and ``val2``.
    """
    if not val1.dtype == val2.dtype:
        raise ValueError("val1 has different dtype than val2.")
    if not torch.is_complex(val1):
        if not val1.shape[dim] == val2.shape[dim] == 2:
            raise ValueError("Real input does not have dimension size 2 at dim.")

    inprod = conj_complex_mult(val2, val1, dim=dim)

    if not torch.is_complex(val1):
        inprod = torch.cat((inprod.select(dim, 0).sum().view(1),
                            inprod.select(dim, 1).sum().view(1)))
    else:
        inprod = torch.sum(inprod)

    return inprod
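# Usage sketch (illustrative): for complex inputs the result equals
# sum(conj(val1) * val2), because the helper conjugates its second argument
# and inner_product passes val1 in that position.
import torch

x = torch.randn(5, dtype=torch.complex64)
y = torch.randn(5, dtype=torch.complex64)
assert torch.allclose(inner_product(x, y), torch.sum(x.conj() * y), atol=1e-5)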
def ebemulcc(A, B):
    """Element-by-element complex multiplication like A .* B in matlab

    Parameters
    ----------
    A : {torch array}
        any size torch array; both complex and real representations are
        supported. For the real representation, the last dimension is 2
        (the first element is the real part, the second the imaginary part).
    B : {torch array}
        any size torch array; both complex and real representations are
        supported. For the real representation, the last dimension is 2
        (the first element is the real part, the second the imaginary part).
        :attr:`B` has the same size as :attr:`A`.

    Returns
    -------
    torch array
        result of element-by-element complex multiplication with the same
        representation as :attr:`A`.
    """
    if th.is_complex(A) and th.is_complex(B):
        return A.real * B.real - A.imag * B.imag + 1j * (A.real * B.imag +
                                                         A.imag * B.real)
    else:
        return th.stack((A[..., 0] * B[..., 0] - A[..., 1] * B[..., 1],
                         A[..., 0] * B[..., 1] + A[..., 1] * B[..., 0]),
                        dim=-1)
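# Usage sketch (illustrative): element-wise complex product, checked against
# the native `*` operator in both representations.
import torch as th

a = th.randn(2, 3, dtype=th.complex64)
b = th.randn(2, 3, dtype=th.complex64)
assert th.allclose(ebemulcc(a, b), a * b, atol=1e-5)

out = ebemulcc(th.view_as_real(a), th.view_as_real(b))  # (..., 2) in and out
assert th.allclose(th.view_as_complex(out), a * b, atol=1e-5)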
def dist_proj(X, Y):
    """Projection distance between the subspaces spanned by the columns of X
    and Y: the Frobenius norm of the difference of their orthogonal
    projection matrices, scaled by 1/sqrt(2).
    """
    # Orthogonal projectors onto the column spaces of X and Y
    Px = torch.matmul(
        X, torch.matmul(torch.inverse(torch.matmul(X.conj().t(), X)),
                        X.conj().t()))
    Py = torch.matmul(
        Y, torch.matmul(torch.inverse(torch.matmul(Y.conj().t(), Y)),
                        Y.conj().t()))
    if torch.is_complex(X) or torch.is_complex(Y):
        P = Px - Py
        # Frobenius norm via trace(P P^H), matching the real branch below
        return torch.sqrt(
            torch.diagonal(torch.matmul(P, P.conj().t())).sum().real) / np.sqrt(2)
    else:
        return torch.norm(Px - Py) / np.sqrt(2)
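# Usage sketch (illustrative): the distance is zero for two bases spanning
# the same subspace and positive otherwise; inputs are (n, k) matrices with
# full column rank.
import torch

Q1, _ = torch.linalg.qr(torch.randn(4, 2))
Q2, _ = torch.linalg.qr(torch.randn(4, 2))
assert dist_proj(Q1, Q1) < 1e-5        # same subspace
assert dist_proj(Q1, 2.0 * Q1) < 1e-5  # column scaling leaves the span unchanged
print(float(dist_proj(Q1, Q2)))        # > 0 for generically different subspaces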
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            has_sparse_grad: bool,
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """
    if weight_decay != 0:
        if has_sparse_grad:
            raise RuntimeError(
                "weight_decay option is not compatible with sparse gradients")
        torch._foreach_add_(grads, params, alpha=weight_decay)

    minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in state_steps]

    if has_sparse_grad:
        # sparse is not supported by multi_tensor. Fall back to optim.adagrad
        # implementation for sparse gradients
        for i, (param, grad, state_sum, step) in enumerate(
                zip(params, grads, state_sums, state_steps)):
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std_sparse = state_sum.sparse_mask(grad)
            std_sparse_values = std_sparse._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_sparse_values),
                alpha=minus_clr[i],
            )
    else:
        grads = [
            torch.view_as_real(x) if torch.is_complex(x) else x for x in grads
        ]
        state_sums = [
            torch.view_as_real(x) if torch.is_complex(x) else x
            for x in state_sums
        ]
        torch._foreach_addcmul_(state_sums, grads, grads, value=1)
        std = torch._foreach_add(torch._foreach_sqrt(state_sums), eps)
        toAdd = torch._foreach_div(torch._foreach_mul(grads, minus_clr), std)
        toAdd = [
            torch.view_as_complex(x) if torch.is_complex(params[i]) else x
            for i, x in enumerate(toAdd)
        ]
        torch._foreach_add_(params, toAdd)
        state_sums = [
            torch.view_as_complex(x) if torch.is_complex(params[i]) else x
            for i, x in enumerate(state_sums)
        ]
def _multi_tensor_adagrad(params: List[Tensor],
                          grads: List[Tensor],
                          state_sums: List[Tensor],
                          state_steps: List[Tensor],
                          *,
                          lr: float,
                          weight_decay: float,
                          lr_decay: float,
                          eps: float,
                          has_sparse_grad: bool):
    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    if has_sparse_grad is None:
        has_sparse_grad = any(grad.is_sparse for grad in grads)

    if has_sparse_grad:
        return _single_tensor_adagrad(params,
                                      grads,
                                      state_sums,
                                      state_steps,
                                      lr=lr,
                                      weight_decay=weight_decay,
                                      lr_decay=lr_decay,
                                      eps=eps,
                                      has_sparse_grad=has_sparse_grad)

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    # state_steps holds singleton tensors; extract Python scalars for the
    # per-parameter learning rates
    minus_clr = [-lr / (1 + (step.item() - 1) * lr_decay) for step in state_steps]

    grads = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in grads
    ]
    state_sums = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in state_sums
    ]
    torch._foreach_addcmul_(state_sums, grads, grads, value=1)
    std = torch._foreach_add(torch._foreach_sqrt(state_sums), eps)
    toAdd = torch._foreach_div(torch._foreach_mul(grads, minus_clr), std)
    toAdd = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(toAdd)
    ]
    torch._foreach_add_(params, toAdd)
    state_sums = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(state_sums)
    ]
def detect_complex(x):
    if isinstance(x, list):
        return any(isinstance(v, complex) for v in x)
    elif isinstance(x, torch.Tensor):
        return torch.is_complex(x)
    else:
        return isinstance(x, complex)
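# Quick checks (illustrative) covering the three accepted input kinds:
import torch

assert detect_complex(1 + 2j)
assert detect_complex([1.0, 2j, 3])
assert detect_complex(torch.zeros(2, dtype=torch.complex64))
assert not detect_complex(torch.zeros(2))
assert not detect_complex(3.14)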
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Tensor or array-like to filter. Must be real, in shape
            ``[Batch, chns, spatial1, spatial2, ...]`` and have a device type
            of ``'cpu'``.

    Returns:
        torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window
        length ``self.window_length`` using polynomials of order
        ``self.order``, along axis specified in ``self.axis``.
    """
    # Make input a real tensor on the CPU
    x = torch.as_tensor(
        x, device=x.device if isinstance(x, torch.Tensor) else None)
    if torch.is_complex(x):
        raise ValueError("x must be real.")
    x = x.to(dtype=torch.float)

    if (self.axis < 0) or (self.axis > len(x.shape) - 1):
        raise ValueError("Invalid axis for shape of x.")

    # Create list of filter kernels (1 per spatial dimension). The kernel for
    # self.axis will be the savgol coeffs, while the other kernels will be
    # set to [1].
    n_spatial_dims = len(x.shape) - 2
    spatial_processing_axis = self.axis - 2
    new_dims_before = spatial_processing_axis
    new_dims_after = n_spatial_dims - spatial_processing_axis - 1
    kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)]
    for _ in range(new_dims_before):
        kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype))
    for _ in range(new_dims_after):
        kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype))

    return separable_filtering(x, kernel_list, mode=self.mode)
def forward(self, input: ComplexTensor, ilens: torch.Tensor):
    """Forward.

    Args:
        input (ComplexTensor): spectrum [Batch, T, (C,) F]
        ilens (torch.Tensor): input lengths [Batch]
    """
    if not isinstance(input, ComplexTensor) and (
            is_torch_1_9_plus and not torch.is_complex(input)):
        raise TypeError("Only support complex tensors for stft decoder")

    bs = input.size(0)
    if input.dim() == 4:
        multi_channel = True
        # input: (Batch, T, C, F) -> (Batch * C, T, F)
        input = input.transpose(1, 2).reshape(-1, input.size(1), input.size(3))
    else:
        multi_channel = False

    wav, wav_lens = self.stft.inverse(input, ilens)

    if multi_channel:
        # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
        wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

    return wav, wav_lens
def conj_complex_mult(val1: Tensor, val2: Tensor, dim: int = -1) -> Tensor:
    """Complex multiplication, conjugating second input.

    Args:
        val1: A tensor to be multiplied.
        val2: A second tensor to be conjugated then multiplied.
        dim: An integer indicating the complex dimension (for real inputs
            only).

    Returns:
        ``val3 = val1 * conj(val2)``, where * executes complex multiplication.
    """
    if not val1.dtype == val2.dtype:
        raise ValueError("val1 has different dtype than val2.")

    if torch.is_complex(val1):
        val3 = val1 * val2.conj()
    else:
        if not val1.shape[dim] == val2.shape[dim] == 2:
            raise ValueError("Real input does not have dimension size 2 at dim.")

        real_a = val1.select(dim, 0)
        imag_a = val1.select(dim, 1)
        real_b = val2.select(dim, 0)
        imag_b = val2.select(dim, 1)

        val3 = torch.stack((real_a * real_b + imag_a * imag_b,
                            imag_a * real_b - real_a * imag_b), dim)

    return val3
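# Usage sketch (illustrative): the complex path computes val1 * conj(val2);
# the real path selects the (real, imag) pair along `dim` and returns the
# same layout.
import torch

v1 = torch.randn(4, dtype=torch.complex64)
v2 = torch.randn(4, dtype=torch.complex64)
assert torch.allclose(conj_complex_mult(v1, v2), v1 * v2.conj(), atol=1e-5)

r = conj_complex_mult(torch.view_as_real(v1), torch.view_as_real(v2), dim=-1)
assert torch.allclose(torch.view_as_complex(r), v1 * v2.conj(), atol=1e-5)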
def forward(self, P, G):
    D = P.dim()
    if self.axis is None:
        axis = list(range(1, D)) if D > 2 else list(range(0, D))
    else:
        axis = self.axis
    axis = [a + D if a < 0 else a for a in axis]

    caxis = self.caxis
    if th.is_complex(P):
        caxis = None

    if caxis is not None:
        caxis = self.caxis + D if self.caxis < 0 else self.caxis
        if caxis != D - 1:
            newshape = list(range(0, caxis)) + list(range(caxis + 1, D)) + [caxis]
            P = P.permute(newshape)
            G = G.permute(newshape)
            axis = [a if a < caxis else a - 1 for a in axis]
        P = P[..., 0] + 1j * P[..., 1]
        G = G[..., 0] + 1j * G[..., 1]

    if self.norm in ['max', 'MAX', 'Max']:
        maxv = G.abs().max() + 1e-16
        P = P / maxv
        G = G / maxv

    for a in axis:
        P = th.fft.fft(P, n=None, dim=a)
        G = th.fft.fft(G, n=None, dim=a)

    P, G = P.angle(), G.angle()

    return self.lossfn(P, G)
def forward(self, X):
    if th.is_complex(X):
        X = X.abs()
    elif (self.caxis is None) or X.shape[-1] == 2:
        X = X.pow(2).sum(axis=-1, keepdims=True).sqrt()
    else:
        X = X.pow(2).sum(axis=self.caxis, keepdims=True).sqrt()

    if X.dtype not in (th.float32, th.float64):
        X = X.to(th.float32)

    if self.axis is None:
        D = X.dim()
        axis = list(range(1, D)) if D > 2 else list(range(0, D))
    else:
        axis = self.axis

    X = th.mean(X.pow(self.p), axis=axis).pow(1. / self.p)

    if self.reduction == 'mean':
        F = th.mean(X)
    if self.reduction == 'sum':
        F = th.sum(X)
    return th.log(F)
def perform_filter_operation(
    Y: Union[torch.Tensor, ComplexTensor],
    filter_matrix_conj: Union[torch.Tensor, ComplexTensor],
    taps,
    delay,
) -> Union[torch.Tensor, ComplexTensor]:
    """perform_filter_operation

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter_matrix_conj : Filter matrix of shape (F, taps, C, C)
    """
    if isinstance(Y, ComplexTensor):
        complex_module = FC
        pad_func = FC.pad
    elif is_torch_1_9_plus and torch.is_complex(Y):
        complex_module = torch
        pad_func = F.pad
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support.")

    T = Y.size(-1)

    # Y_tilde: (taps, F, C, T)
    Y_tilde = complex_module.stack(
        [
            pad_func(Y[:, :, :T - delay - i], (delay + i, 0),
                     mode="constant", value=0)
            for i in range(taps)
        ],
        dim=0,
    )
    reverb_tail = complex_module.einsum("fpde,pfdt->fet",
                                        (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
def forward(self, X):
    if th.is_complex(X):
        X = (X * X.conj()).real
    elif X.size(-1) == 2:
        X = X.pow(2).sum(axis=-1)

    D = X.dim()
    axis = list(range(1, D))

    if X.dtype not in (th.float32, th.float64):
        X = X.to(th.float32)

    if self.mode in ['way1', 'WAY1']:
        Xmean = X.mean(axis=axis, keepdims=True)
        C = (X - Xmean).pow(2).mean(axis=axis, keepdims=True).sqrt() / (Xmean + EPS)
    if self.mode in ['way2', 'WAY2']:
        C = X.mean(axis=axis, keepdims=True) / (
            (X.sqrt().mean(axis=axis, keepdims=True)).pow(2) + EPS)

    if self.reduction == 'mean':
        C = th.mean(C)
    if self.reduction == 'sum':
        C = th.sum(C)
    return -C
def forward(self, X):
    if th.is_complex(X):
        X = (X * X.conj()).real
    elif (self.caxis is None) or X.shape[-1] == 2:
        X = th.sum(X.pow(2), axis=-1, keepdims=True)
    else:
        X = th.sum(X.pow(2), axis=self.caxis, keepdims=True)

    if self.axis is None:
        D = X.dim()
        axis = list(range(1, D)) if D > 2 else list(range(0, D))
    else:
        axis = self.axis

    P = th.sum(X, axis, keepdims=True)
    p = X / (P + EPS)

    if self.mode in ['Shannon', 'shannon', 'SHANNON']:
        S = -th.sum(p * th.log2(p + EPS), axis)
    if self.mode in ['Natural', 'natural', 'NATURAL']:
        S = -th.sum(p * th.log(p + EPS), axis)

    if self.reduction == 'mean':
        S = th.mean(S)
    if self.reduction == 'sum':
        S = th.sum(S)
    return S
def forward(self, P, G):
    D = P.dim()
    caxis = self.caxis
    if th.is_complex(P):
        caxis = None

    if caxis is not None:
        caxis = self.caxis + D if self.caxis < 0 else self.caxis
        if caxis != D - 1:
            newshape = list(range(0, caxis)) + list(range(caxis + 1, D)) + [caxis]
            P = P.permute(newshape)
            G = G.permute(newshape)
        if P.shape[-1] == 2:
            P = P[..., 0] + 1j * P[..., 1]
        if G.shape[-1] == 2:
            G = G[..., 0] + 1j * G[..., 1]

    axis = list(range(1, P.dim()))

    if self.norm in ['max', 'MAX', 'Max']:
        maxv = G.abs().max() + 1e-16
        P = P / maxv
        G = G / maxv

    F = th.mean((P - G).abs(), axis=axis)

    if self.reduction == 'mean':
        F = th.mean(F)
    if self.reduction == 'sum':
        F = th.sum(F)
    return F
def frobenius(X, p=2, reduction='mean'):
    r"""Frobenius-style norm

    .. math::
        \|{\bm X}\|_p = \left(\sum_i |x_i|^p\right)^{1/p}

    Here the sum is replaced by a mean over the non-batch dimensions, and the
    per-sample values are then reduced over the batch.
    """
    if th.is_complex(X):
        X = ((X * X.conj()).real).sqrt()
    elif X.size(-1) == 2:
        X = X.pow(2).sum(axis=-1).sqrt()

    if X.dtype not in (th.float32, th.float64):
        X = X.to(th.float32)

    D = X.dim()
    dim = list(range(1, D))
    X = th.mean(X.pow(p), axis=dim).pow(1. / p)

    if reduction == 'mean':
        F = th.mean(X)
    if reduction == 'sum':
        F = th.sum(X)
    return F
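# Usage sketch (illustrative): the implementation averages |x|^p over the
# non-batch dimensions before taking the 1/p root, so for p=2 each sample
# contributes its RMS magnitude rather than a sum-based Frobenius norm.
import torch as th

X = th.randn(4, 3, 3, dtype=th.complex64)
out = frobenius(X, p=2, reduction='sum')
ref = X.abs().pow(2).mean(dim=(1, 2)).sqrt().sum()
assert th.allclose(out, ref, atol=1e-5)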
def ifft(x, n=None, axis=0, norm="backward", shift=False):
    """IFFT in torchsar

    IFFT in torchsar. Both the native complex representation and the real
    representation (last dimension is 2: real, imag) are supported. When
    :attr:`x` is given in the real representation, it is viewed as a complex
    array for the transform and converted back afterwards.

    Parameters
    ----------
    x : {torch array}
        both complex and real representations are supported.
    n : int, optional
        number of ifft points (the default is None, which equals the signal
        dimension)
    axis : int, optional
        axis of ifft (the default is 0, the first dimension)
    norm : str, optional
        Normalization mode. For the backward transform (ifft()), these
        correspond to:

        - "forward" - no normalization
        - "backward" - normalize by ``1/n`` (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        ifft results torch array with the same type as :attr:`x`

    Raises
    ------
    ValueError
        nfft is smaller than the signal dimension.
    """
    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            axis += 1
    else:
        realflag = False

    if shift:
        y = thfft.ifftshift(
            thfft.ifft(thfft.ifftshift(x, dim=axis), n=n, dim=axis, norm=norm),
            dim=axis)
    else:
        y = thfft.ifft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y
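# Round-trip sketch (illustrative): ifft(fft(x)) recovers x, and the
# (..., 2) real representation is restored on output. Assumes the host
# module imports torch.fft as thfft, as the function body requires.
import torch as th

x = th.randn(8, dtype=th.complex64)
X = th.fft.fft(x, dim=0)
assert th.allclose(ifft(X, axis=0), x, atol=1e-5)

Xr = th.view_as_real(X)      # real representation, last dim = 2
yr = ifft(Xr, axis=0)
assert yr.shape[-1] == 2
assert th.allclose(th.view_as_complex(yr), x, atol=1e-5)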
def forward(self, X):
    if th.is_complex(X):
        X = (X * X.conj()).real
    elif (self.caxis is None) or X.shape[-1] == 2:
        X = th.sum(X.pow(2), axis=-1, keepdims=True)
    else:
        X = th.sum(X.pow(2), axis=self.caxis, keepdims=True)

    if self.axis is None:
        D = X.dim()
        axis = list(range(1, D)) if D > 2 else list(range(0, D))
    else:
        axis = self.axis

    if X.dtype not in (th.float32, th.float64):
        X = X.to(th.float32)

    if self.mode in ['way1', 'WAY1']:
        Xmean = X.mean(axis=axis, keepdims=True)
        C = (X - Xmean).pow(2).mean(axis=axis, keepdims=True).sqrt() / (Xmean + EPS)
    if self.mode in ['way2', 'WAY2']:
        C = X.mean(axis=axis, keepdims=True) / (
            (X.sqrt().mean(axis=axis, keepdims=True)).pow(2) + EPS)

    if self.reduction == 'mean':
        C = th.mean(C)
    if self.reduction == 'sum':
        C = th.sum(C)
    return -C
def getLamdaGaplist(lambdas: torch.Tensor):
    """
    Calculate the gaps between lambda values.
    """
    if torch.is_complex(lambdas):
        lambdas = torch.real(lambdas)
    return lambdas[1:] - lambdas[:-1]
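# Usage sketch (illustrative): successive differences of an eigenvalue vector.
import torch

lambdas = torch.tensor([0.1, 0.5, 0.9, 2.0])
gaps = getLamdaGaplist(lambdas)  # gaps[i] = lambdas[i + 1] - lambdas[i]
assert torch.allclose(gaps, torch.tensor([0.4, 0.4, 1.1]))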
def forward(self, X):
    if th.is_complex(X):
        X = (X * X.conj()).real
    elif X.size(-1) == 2:
        X = th.sum(X.pow(2), axis=-1)

    if X.dim() == 2:
        axis = (0, 1)
    if X.dim() == 3:
        axis = (1, 2)
    if X.dim() == 4:
        axis = (1, 2, 3)

    P = th.sum(X, axis, keepdims=True)
    p = X / (P + EPS)

    if self.mode in ['Shannon', 'shannon', 'SHANNON']:
        S = -th.sum(p * th.log2(p + EPS), axis)
    if self.mode in ['Natural', 'natural', 'NATURAL']:
        S = -th.sum(p * th.log(p + EPS), axis)

    if self.reduction == 'mean':
        S = th.mean(S)
    if self.reduction == 'sum':
        S = th.sum(S)
    return S
def negative_norm(
    x: torch.FloatTensor,
    p: Union[str, int] = 2,
    power_norm: bool = False,
) -> torch.FloatTensor:
    r"""Evaluate the negative norm of a vector.

    :param x: shape: (batch_size, num_heads, num_relations, num_tails, dim)
        The vectors.
    :param p:
        The p for the norm. cf. torch.norm.
    :param power_norm:
        Whether to return the powered norm :math:`\|x\|_p^p` instead of
        :math:`\|x\|_p`, cf. https://github.com/pytorch/pytorch/issues/28119

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    if power_norm:
        assert not isinstance(p, str)
        return -(x.abs() ** p).sum(dim=-1)

    if torch.is_complex(x):
        assert not isinstance(p, str)
        # workaround for complex numbers: manually compute norm
        return -(x.abs() ** p).sum(dim=-1) ** (1 / p)

    return -x.norm(p=p, dim=-1)
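# Usage sketch (illustrative): for real inputs with p=2 the result matches
# the ordinary negative Euclidean norm; complex inputs take the manual
# |x|-based route.
import torch

x = torch.randn(2, 3, 4, 5, 16)
assert torch.allclose(negative_norm(x), -x.norm(p=2, dim=-1), atol=1e-5)

xc = torch.randn(2, 3, 4, 5, 16, dtype=torch.complex64)
ref = -(xc.abs() ** 2).sum(dim=-1) ** 0.5
assert torch.allclose(negative_norm(xc, p=2), ref, atol=1e-5)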
def forward(self, X):
    D = X.dim()
    if self.axis is None:
        axis = list(range(1, D)) if D > 2 else list(range(0, D))
    else:
        axis = self.axis
    axis = [a + D if a < 0 else a for a in axis]

    if th.is_complex(X):
        caxis = None
    else:
        caxis = self.caxis + D if self.caxis < 0 else self.caxis
        if caxis != D - 1:
            newshape = list(range(0, caxis)) + list(range(caxis + 1, D)) + [caxis]
            X = X.permute(newshape)
            axis = [a if a < caxis else a - 1 for a in axis]
        X = X[..., 0] + 1j * X[..., 1]

    for a in axis:
        X = th.fft.fft(X, n=None, dim=a)

    X = X.abs()
    S = th.sum(th.log2(1 + X / self.p), axis)

    if self.reduction == 'mean':
        S = th.mean(S)
    if self.reduction == 'sum':
        S = th.sum(S)
    return S
def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    pad_value=0,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expands signal into frames of frame_length.

    Args:
        signal : (B * F, D, T)
    Returns:
        torch.Tensor: (B * F, D, T, W)
    """
    if isinstance(signal, ComplexTensor):
        real = signal_framing(signal.real, frame_length, frame_step, pad_value)
        imag = signal_framing(signal.imag, frame_length, frame_step, pad_value)
        return ComplexTensor(real, imag)
    elif is_torch_1_9_plus and torch.is_complex(signal):
        real = signal_framing(signal.real, frame_length, frame_step, pad_value)
        imag = signal_framing(signal.imag, frame_length, frame_step, pad_value)
        return torch.complex(real, imag)
    else:
        signal = F.pad(signal, (0, frame_length - 1), "constant", pad_value)
        indices = sum(
            [
                list(range(i, i + frame_length))
                for i in range(0, signal.size(-1) - frame_length + 1, frame_step)
            ],
            [],
        )
        signal = signal[..., indices].view(*signal.size()[:-1], -1, frame_length)
        return signal
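# Shape sketch (illustrative; assumes the host module imports
# torch.nn.functional as F): framing a length-10 signal with frame_length=4
# and frame_step=2 yields 5 overlapping frames, since the end is zero-padded
# by frame_length - 1 samples before indexing.
import torch

sig = torch.arange(10.0).reshape(1, 1, 10)  # (B * F, D, T)
frames = signal_framing(sig, frame_length=4, frame_step=2)
assert frames.shape == (1, 1, 5, 4)
assert torch.equal(frames[0, 0, 0], torch.tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(frames[0, 0, 1], torch.tensor([2.0, 3.0, 4.0, 5.0]))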
def batch_fft(data, normalize=False):
    """
    Compute Fourier transform of batch.

    Args:
        data: input tensor, (NxHxW)
        normalize: whether to use the orthonormal ("ortho") FFT normalization

    Returns:
        Batch Fourier transform of input data.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')
    dims = tuple(range(1, dim + 1))  # add one for batch dimension

    norm = 'ortho' if normalize else 'backward'

    if not torch.is_complex(data):
        data = torch.complex(data, torch.zeros_like(data))

    freq = fftn(data, dim=dims, norm=norm)
    return freq
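# Usage sketch (illustrative; assumes `from torch.fft import fftn` at module
# level, which the function body relies on): real input is promoted to
# complex before the 2-D transform over the spatial dimensions.
import torch
from torch.fft import fftn

batch = torch.randn(4, 8, 8)  # (N, H, W)
freq = batch_fft(batch, normalize=True)
ref = fftn(batch.to(torch.complex64), dim=(1, 2), norm='ortho')
assert torch.allclose(freq, ref, atol=1e-5)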
def forward(self, X, w=None):
    r"""Weighted focus value.

    Parameters
    ----------
    X : {torch tensor}
        data after FFT in azimuth; complex, or real with the last dimension
        being 2 (real, imag). The magnitude is used.
    w : {torch tensor}, optional
        weights applied along axis -2 (the default is None, which uses
        all-ones weights)

    Returns
    -------
    {torch tensor}
        reduced focus value
    """
    if th.is_complex(X):
        X = X.abs()
    elif X.shape[-1] == 2:
        X = th.view_as_complex(X)
        X = X.abs()

    if w is None:
        wshape = [1] * (X.dim())
        wshape[-2] = X.size(-2)
        w = th.ones(wshape, device=X.device, dtype=X.dtype)

    fv = th.sum((th.sum(w * X, axis=-2)).pow(self.p), axis=-1)

    if self.reduction == 'mean':
        C = th.mean(fv)
    if self.reduction == 'sum':
        C = th.sum(fv)
    return C
def _multi_tensor_adamax(params: List[Tensor],
                         grads: List[Tensor],
                         exp_avgs: List[Tensor],
                         exp_infs: List[Tensor],
                         state_steps: List[Tensor],
                         *,
                         beta1: float,
                         beta2: float,
                         lr: float,
                         weight_decay: float,
                         eps: float,
                         maximize: bool):
    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(grads)

    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    exp_infs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_infs]

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    # Update biased first moment estimate.
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    # Update the exponentially weighted infinity norm.
    torch._foreach_mul_(exp_infs, beta2)

    for exp_inf, grad in zip(exp_infs, grads):
        norm_buf = torch.cat([
            exp_inf.unsqueeze(0),
            grad.abs().add_(eps).unsqueeze_(0)
        ], 0)
        torch.max(norm_buf, 0, keepdim=False,
                  out=(exp_inf, exp_inf.new().long()))

    bias_corrections = [1 - beta1 ** step.item() for step in state_steps]
    clr = [-1 * (lr / bias_correction) for bias_correction in bias_corrections]
    torch._foreach_addcdiv_(params, exp_avgs, exp_infs, clr)