def spmat_interp_adjoint(
    data: Tensor,
    interp_mats: Union[Tensor, Tuple[Tensor, Tensor]],
    grid_size: Tensor,
) -> Tensor:
    """Sparse matrix interpolation adjoint backend."""
    if not isinstance(interp_mats, tuple):
        raise TypeError("interp_mats must be 2-tuple of (real_mat, imag_mat).")

    coef_mat_real, coef_mat_imag = interp_mats
    batch_size, num_coils = data.shape[:2]

    # sparse matrix multiply requires real data
    data = torch.view_as_real(data)
    output_size = [batch_size, num_coils] + grid_size.tolist()

    # we have to do these transposes because torch.mm requires the first
    # argument to be the sparse matrix
    real_kdat = data.select(-1, 0).view(-1, data.shape[-2]).t().contiguous()
    imag_kdat = data.select(-1, 1).view(-1, data.shape[-2]).t().contiguous()
    coef_mat_real = coef_mat_real.t()
    coef_mat_imag = coef_mat_imag.t()

    # apply the multiplies with the complex conjugate
    image = torch.stack(
        [
            (torch.mm(coef_mat_real, real_kdat) + torch.mm(coef_mat_imag, imag_kdat)).t(),
            (torch.mm(coef_mat_real, imag_kdat) - torch.mm(coef_mat_imag, real_kdat)).t(),
        ],
        dim=-1,
    )

    return torch.view_as_complex(image).reshape(*output_size)
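# Hedged usage sketch for spmat_interp_adjoint above, with tiny hypothetical
# shapes: a (klen x prod(grid_size)) interpolation matrix split into real and
# imaginary sparse parts, applied in adjoint mode to k-space data. All names
# and sizes here are illustrative, not part of any library API.
import torch

grid_size = torch.tensor([4, 4])
klen = 6
coef_real = torch.rand(klen, 16).to_sparse()
coef_imag = torch.rand(klen, 16).to_sparse()
kdata = torch.randn(1, 1, klen, dtype=torch.complex64)  # (batch, coils, klen)
image = spmat_interp_adjoint(kdata, (coef_real, coef_imag), grid_size)
print(image.shape)  # torch.Size([1, 1, 4, 4])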
def dct(x: torch.Tensor, norm: Optional[str] = None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = torch.tensor(x_shape[-1], device=x.device)
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # On torch >= 1.8; older releases used the since-removed torch.rfft:
    # Vc = torch.rfft(v, 1, onesided=False)
    Vc = torch.view_as_real(torch.fft.fft(v))

    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == "ortho":
        V[:, 0] /= torch.sqrt(4 * N)   # == np.sqrt(N) * 2
        V[:, 1:] /= torch.sqrt(2 * N)  # == np.sqrt(N / 2) * 2

    V = 2 * V.view(x_shape)
    return V
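# Consistency sketch (assumes scipy is installed): the ortho-normalized
# DCT-II above should match scipy.fftpack.dct up to numerical precision.
import numpy as np
import torch
from scipy.fftpack import dct as scipy_dct

x = torch.randn(4, 16, dtype=torch.float64)
expected = scipy_dct(x.numpy(), type=2, norm="ortho")
actual = dct(x, norm="ortho").numpy()
assert np.allclose(actual, expected, atol=1e-10)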
def forward(self, x: Tensor) -> Tensor:
    """STFT forward path

    Args:
        x (Tensor): audio waveform of shape (nb_samples, nb_channels, nb_timesteps)

    Returns:
        STFT (Tensor): complex stft of shape
            (nb_samples, nb_channels, nb_bins, nb_frames, complex=2);
            the last axis is the stacked real and imaginary parts
    """
    shape = x.size()
    nb_samples, nb_channels, nb_timesteps = shape

    # pack batch
    x = x.view(-1, shape[-1])

    complex_stft = torch.stft(
        x,
        n_fft=self.n_fft,
        hop_length=self.n_hop,
        window=self.window,
        center=self.center,
        normalized=False,
        onesided=True,
        pad_mode="reflect",
        return_complex=True,
    )
    stft_f = torch.view_as_real(complex_stft)

    # unpack batch
    stft_f = stft_f.view(shape[:-1] + stft_f.shape[-3:])
    return stft_f
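# Shape sketch of the forward path above, using torch.stft directly with
# hypothetical parameters (n_fft=4096, hop=1024) in place of the module's
# attributes:
import torch

x = torch.randn(2, 2, 44100)  # (nb_samples, nb_channels, nb_timesteps)
flat = x.reshape(-1, x.shape[-1])  # pack batch
spec = torch.view_as_real(torch.stft(
    flat, n_fft=4096, hop_length=1024, window=torch.hann_window(4096),
    center=True, return_complex=True))
spec = spec.view(x.shape[:-1] + spec.shape[-3:])  # unpack batch
print(spec.shape)  # torch.Size([2, 2, 2049, 44, 2])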
def grid_1d_from_2d(self, x, dx, vis, y):
    nvis, N = x.shape
    W = self.config['W']
    Dx = W * dx
    xref = torch.ceil((x - 0.5 * W * dx) / dx) * dx
    xndx = torch.arange(W, dtype=xref.dtype, device=xref.device)
    xg = xref + xndx * dx
    gcf_val = _gcf_kaiser(xg - x, Dx, self.beta).float()

    # Batched mm is not yet supported for complex tensors on CUDA, so
    # multiply the real and imaginary parts separately. Equivalent complex
    # one-liner once supported:
    # vis2 = torch.matmul(vis[:, :, None], gcf_val[:, None, :])
    vis_ri = torch.view_as_real(vis)
    vis_r = vis_ri[:, :, -2]
    vis_i = vis_ri[:, :, -1]
    vis2_r = torch.matmul(vis_r[:, :, None], gcf_val[:, None, :])
    vis2_i = torch.matmul(vis_i[:, :, None], gcf_val[:, None, :])
    vis2 = torch.view_as_complex(torch.stack([vis2_r, vis2_i], dim=-1))

    vis2 = vis2.reshape(nvis, -1)
    x2 = xg.repeat(1, N)
    y2 = torch.repeat_interleave(y, W, dim=-1)
    return x2, vis2, y2
def _handle_complex(tensor):
    """
    Return a real view of a tensor if it has a complex dtype, else the tensor itself.

    We need to check for UninitializedParameter first, because calling
    is_complex() on one raises an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter) and tensor.is_complex()
        else tensor
    )
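# Behavior sketch for _handle_complex: complex tensors come back as a real
# view with a trailing size-2 dim; real tensors pass through unchanged.
import torch

z = torch.randn(3, dtype=torch.complex64)
assert _handle_complex(z).shape == (3, 2)
r = torch.randn(3)
assert _handle_complex(r) is r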
def _test_all_gather(self, group, group_id, rank, cuda=False, rank_to_GPU=None,
                     dtype=torch.float, qtype=None):
    for dest in group:
        tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype)
        tensors = [
            _build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group
        ]
        expected_tensors = [
            _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group
        ]
        if cuda:
            tensor = tensor.cuda(rank_to_GPU[rank][0])
            tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
        if tensors[0].dtype == torch.complex64:
            tensor_shapes = [torch.view_as_real(tensors[0]).shape]
        else:
            tensor_shapes = [tensors[0].shape]
        allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
        allgather(tensors, tensor, group=group_id, async_op=False)

        for t1, t2 in zip(tensors, expected_tensors):
            self.assertEqual(t1, t2)
def ifft(x, n=None, axis=0, norm="backward", shift=False):
    """IFFT in torchsar

    IFFT in torchsar. Since ``ifft`` in torch only supports complex-to-complex
    transforms, for a real ifft we insert an all-zero imaginary part
    (``torch.stack((x, torch.zeros_like(x)), dim=-1)``); alternatively you can
    use torch's ``irfft``.

    Parameters
    ----------
    x : Tensor
        Both complex and real representations are supported. When :attr:`x` is
        in real representation (last dimension is 2: real, imag), we convert it
        to a complex tensor, and after the IFFT convert it back.
    n : int, optional
        number of ifft points (the default is None --> equals the signal dimension)
    axis : int, optional
        axis of ifft (the default is 0, the first dimension)
    norm : str, optional
        Normalization mode. For the backward transform (ifft()), these correspond to:

        - "forward" - no normalization
        - "backward" - normalize by ``1/n`` (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
    shift : bool, optional
        shift the zero frequency to the center (the default is False)

    Returns
    -------
    y : Tensor
        ifft results, a torch tensor with the same type as :attr:`x`

    Raises
    ------
    ValueError
        ``n`` is smaller than the signal dimension.
    """
    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            axis += 1
    else:
        realflag = False

    if shift:
        y = thfft.ifftshift(
            thfft.ifft(thfft.ifftshift(x, dim=axis), n=n, dim=axis, norm=norm),
            dim=axis)
    else:
        y = thfft.ifft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y
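# Round-trip sketch for the pseudo-complex path of ifft above, with a
# hypothetical input size (assumes the module's `th`/`thfft` aliases for
# torch / torch.fft are in scope):
import torch as th

x = th.randn(8, 2)  # real representation: last dim is (real, imag)
y = ifft(x, axis=0)
assert y.shape == x.shape and not th.is_complex(y)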
def bound_complex_mask(mask: ComplexTensor, bound_type="tanh"):
    r"""Bound a complex mask, as proposed in [1], section 3.2.

    Valid bound types, for a complex mask $M = |M| ⋅ e^{i φ(M)}$:

    - Unbounded ("UBD"): :math:`M_{\mathrm{UBD}} = M`
    - Sigmoid ("BDSS"): :math:`M_{\mathrm{BDSS}} = σ(|M|) e^{i σ(φ(M))}`
    - Tanh ("BDT"): :math:`M_{\mathrm{BDT}} = \mathrm{tanh}(|M|) e^{i φ(M)}`

    Args:
        bound_type (str or None): The type of bound to use, either of
            "tanh"/"BDT" (default), "sigmoid"/"BDSS" or None/"UBD".

    References
        - [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net",
          Hyeong-Seok Choi et al. https://arxiv.org/abs/1903.03107
    """
    if bound_type in {"BDSS", "sigmoid"}:
        return on_reim(torch.sigmoid)(mask)
    elif bound_type in {"BDT", "tanh", "UBD", None}:
        mask_mag, mask_phase = torchaudio.functional.magphase(torch.view_as_real(mask))
        if bound_type in {"BDT", "tanh"}:
            mask_mag_bounded = torch.tanh(mask_mag)
        else:
            mask_mag_bounded = mask_mag
        return torch_complex_from_magphase(mask_mag_bounded, mask_phase)
    else:
        raise ValueError(f"Unknown mask bound {bound_type}")
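# Minimal sketch of the "BDT" tanh bound using plain torch complex ops,
# equivalent in spirit to the torchaudio magphase path above (torch >= 1.8):
import torch

mask = torch.randn(3, 5, dtype=torch.complex64)
bounded = torch.polar(torch.tanh(mask.abs()), mask.angle())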
def test_batch_TimeStretch(self, test_pseudo_complex):
    rate = 2
    num_freq = 1025
    num_frames = 400

    spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
    pattern = [3, 1, 1, 1]
    if test_pseudo_complex:
        spec = torch.view_as_real(spec)
        pattern += [1]

    # Transform the single example, then batch
    expected = torchaudio.transforms.TimeStretch(
        fixed_rate=rate,
        n_freq=num_freq,
        hop_length=512,
    )(spec).repeat(*pattern)

    # Batch, then transform
    computed = torchaudio.transforms.TimeStretch(
        fixed_rate=rate,
        n_freq=num_freq,
        hop_length=512,
    )(spec.repeat(*pattern))

    self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has
            size 2. All other dimensions are assumed to be batch dimensions.
        norm: Normalization mode. See ``torch.fft.ifft``.

    Returns:
        The IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.ifftn(  # type: ignore
            torch.view_as_complex(data), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data
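# Shape sketch for ifft2c_new (assumes the fftshift/ifftshift helpers used
# above are in scope): the pseudo-complex layout is preserved.
import torch

kspace = torch.randn(2, 16, 16, 2)  # (batch, H, W, complex=2)
img = ifft2c_new(kspace)
assert img.shape == kspace.shape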
def ifft2(data):
    assert data.shape[-1] == 2
    data = ifftshift(data, axes=(-3, -2))
    data = torch.view_as_complex(data)
    data = torch.fft.ifftn(data, dim=(-2, -1), norm='ortho')
    data = torch.view_as_real(data)
    # re-center the zero-frequency component (mirrors ifft2c_new above)
    data = fftshift(data, axes=(-3, -2))
    return data
def get_spec(self, audio):
    with torch.cuda.amp.autocast(enabled=False):
        spec = self.stft(audio)
        if spec.dtype in [torch.cfloat, torch.cdouble]:
            spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
    return spec
def forward(self, input):
    assert input.shape[1] == input.shape[2]

    # padding
    input = self.pad(torch.view_as_complex(input))

    # scaling
    input = self.deformation(input)

    # to Fourier domain
    input = complex_ifftshift(input)
    input = complex_fft(input, 2)
    input = complex_fftshift(input)

    # Zernike layers in the Fourier plane
    input = self.zernike_ft(input)

    # back to direct domain
    input = complex_ifftshift(input)
    input = complex_ifft(input, 2)
    input = complex_fftshift(input)

    # Zernike layers in the direct plane
    input = self.zernike_direct(input)

    # Crop at the center (because of coeff)
    input = crop_center(input, self.nxy)

    return torch.view_as_real(input)
def test_phase_vocoder(self, rate, test_pseudo_complex):
    hop_length = 256
    num_freq = 1025
    num_frames = 400
    torch.random.manual_seed(42)

    # Due to the cumulative sum, numerical error with torch.float32 would
    # cause the bottom-right values of the stretched spectrogram to not
    # match librosa, so use float64/complex128.
    spec = torch.randn(num_freq, num_frames,
                       device=self.device, dtype=torch.complex128)
    phase_advance = torch.linspace(
        0, np.pi * hop_length, num_freq,
        device=self.device, dtype=torch.float64)[..., None]

    stretched = F.phase_vocoder(
        torch.view_as_real(spec) if test_pseudo_complex else spec,
        rate=rate, phase_advance=phase_advance)

    expected_stretched = librosa.phase_vocoder(
        spec.cpu().numpy(), rate=rate, hop_length=hop_length)

    self.assertEqual(
        torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
        torch.from_numpy(expected_stretched))
def spmat_interp(
    image: Tensor, interp_mats: Union[Tensor, Tuple[Tensor, Tensor]]
) -> Tensor:
    """Sparse matrix interpolation backend."""
    if not isinstance(interp_mats, tuple):
        raise TypeError("interp_mats must be 2-tuple of (real_mat, imag_mat).")

    coef_mat_real, coef_mat_imag = interp_mats
    batch_size, num_coils = image.shape[:2]

    # sparse matrix multiply requires real data
    image = torch.view_as_real(image)
    output_size = [batch_size, num_coils, -1]

    # we have to do these transposes because torch.mm requires the first
    # argument to be the sparse matrix
    image = image.reshape(batch_size * num_coils, -1, 2)
    real_griddat = image.select(-1, 0).t().contiguous()
    imag_griddat = image.select(-1, 1).t().contiguous()

    # apply the multiplies
    kdat = torch.stack(
        [
            (torch.mm(coef_mat_real, real_griddat) - torch.mm(coef_mat_imag, imag_griddat)).t(),
            (torch.mm(coef_mat_real, imag_griddat) + torch.mm(coef_mat_imag, real_griddat)).t(),
        ],
        dim=-1,
    )

    return torch.view_as_complex(kdat).reshape(*output_size)
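# Hedged adjoint-consistency sketch for the spmat_interp / spmat_interp_adjoint
# pair above: <A x, y> should equal <x, A^H y> up to float tolerance. Shapes
# and matrices are hypothetical.
import torch

grid_size = torch.tensor([4, 4])
klen = 6
coef_real = torch.rand(klen, 16, dtype=torch.float64).to_sparse()
coef_imag = torch.rand(klen, 16, dtype=torch.float64).to_sparse()
x = torch.randn(1, 1, 4, 4, dtype=torch.complex128)
y = torch.randn(1, 1, klen, dtype=torch.complex128)
lhs = torch.sum(spmat_interp(x, (coef_real, coef_imag)) * y.conj())
rhs = torch.sum(x * spmat_interp_adjoint(y, (coef_real, coef_imag), grid_size).conj())
assert torch.allclose(lhs, rhs, atol=1e-10)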
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    if is_torchver_higher18:
        # For pytorch >= 1.8, return_complex=True is strongly preferred
        x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
                            return_complex=True)
        x_stft = torch.view_as_real(x_stft)
    else:
        x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
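# Shape sketch (hypothetical sizes; assumes the is_torchver_higher18 flag
# above matches your torch version): a (2, 16000) batch with fft_size=1024
# and hop_size=256 yields a (2, 63, 513) magnitude spectrogram.
import torch

x = torch.randn(2, 16000)
mag = stft(x, fft_size=1024, hop_size=256, win_length=1024,
           window=torch.hann_window(1024))
print(mag.shape)  # torch.Size([2, 63, 513])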
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)
    return V
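# Sanity sketch: the ortho-normalized DCT-II of a constant signal puts all
# of its energy in the DC bin, dct(1_N)[0] = sqrt(N).
import torch

x = torch.ones(1, 8)
V = dct(x, norm='ortho')
# V[0, 0] ≈ sqrt(8) ≈ 2.8284; V[0, 1:] ≈ 0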
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
    """Verify the output shape of phase vocoder"""
    hop_length = 256
    num_freq = 1025
    num_frames = 400
    batch_size = 2

    torch.random.manual_seed(42)
    spec = torch.randn(batch_size, num_freq, num_frames,
                       dtype=self.complex_dtype, device=self.device)
    if test_pseudo_complex:
        spec = torch.view_as_real(spec)

    phase_advance = torch.linspace(
        0, np.pi * hop_length, num_freq,
        dtype=self.real_dtype, device=self.device)[..., None]

    spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

    assert spec.dim() == spec_stretch.dim()
    expected_shape = torch.Size(
        [batch_size, num_freq, int(np.ceil(num_frames / rate))])
    output_shape = (torch.view_as_complex(spec_stretch)
                    if test_pseudo_complex else spec_stretch).shape
    assert output_shape == expected_shape
def test_view_as_real(self):
    x = torch.randn(4, dtype=torch.complex64)
    y = torch.view_as_real(x)
    m = MetaConverter()(y)
    self.assertEqual(m.shape, y.shape)
    self.assertEqual(m.stride(), y.stride())
    self.assertEqual(m.dtype, y.dtype)
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
    """Verify that ``T.TimeStretch`` is differentiable when the input is not close to 0

    ``T.TimeStretch`` is not differentiable around 0, so this test checks
    differentiability for inputs that are not zero.

    As tested above, when the spectrogram contains values close to zero,
    the gradients are unstable and gradcheck fails.

    In this test, we generate a spectrogram from a random signal, then push
    the points around zero away from the origin. This does not reflect a
    real use-case and is not practical for users, but it helps us understand
    to what degree the function is differentiable and when it is not.
    """
    n_fft = 16
    transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
    waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2)
    spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)

    # 1e-3 is too small (on CPU)
    epsilon = 1e-2
    too_close = spectrogram.abs() < epsilon
    spectrogram[too_close] = (
        epsilon * spectrogram[too_close] / spectrogram[too_close].abs())
    if test_pseudo_complex:
        spectrogram = torch.view_as_real(spectrogram)
    self.assert_grad(transform, [spectrogram])
def _multi_tensor_adamax(params: List[Tensor],
                         grads: List[Tensor],
                         exp_avgs: List[Tensor],
                         exp_infs: List[Tensor],
                         state_steps: List[Tensor],
                         *,
                         beta1: float,
                         beta2: float,
                         lr: float,
                         weight_decay: float,
                         eps: float,
                         maximize: bool):

    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(grads)

    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    exp_infs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_infs]

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    # Update biased first moment estimate.
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    # Update the exponentially weighted infinity norm.
    torch._foreach_mul_(exp_infs, beta2)
    for exp_inf, grad in zip(exp_infs, grads):
        norm_buf = torch.cat([
            exp_inf.unsqueeze(0),
            grad.abs().add_(eps).unsqueeze_(0)
        ], 0)
        torch.max(norm_buf, 0, keepdim=False,
                  out=(exp_inf, exp_inf.new().long()))

    bias_corrections = [1 - beta1 ** step.item() for step in state_steps]
    clr = [-1 * (lr / bias_correction) for bias_correction in bias_corrections]
    torch._foreach_addcdiv_(params, exp_avgs, exp_infs, clr)
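# Single-parameter usage sketch for _multi_tensor_adamax (hypothetical
# values): the infinity norm becomes max(beta2 * 0, |grad| + eps) = 0.7, the
# first moment becomes (1 - beta1) * 0.7 = 0.07, and with the step-1 bias
# correction the update works out to exactly -lr.
import torch

p = [torch.tensor([1.0])]
g = [torch.tensor([0.7])]
exp_avgs, exp_infs, steps = [torch.zeros(1)], [torch.zeros(1)], [torch.zeros(1)]
_multi_tensor_adamax(p, g, exp_avgs, exp_infs, steps, beta1=0.9, beta2=0.999,
                     lr=0.002, weight_decay=0.0, eps=1e-8, maximize=False)
# p[0] ≈ 0.998 == 1.0 - lr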
def sdsp(
    x: Tensor,
    filtr: Tensor,
    value_range: float = 1.,
    sigma_c: float = 0.001,
    sigma_d: float = 145.,
) -> Tensor:
    r"""Detects salient regions from :math:`x`.

    Args:
        x: An input tensor, :math:`(N, 3, H, W)`.
        filtr: The frequency domain filter, :math:`(H, W)`.
        value_range: The value range :math:`L` of the input (usually `1.` or `255`).

    Note:
        For the remaining arguments, refer to [Zhang2013]_.

    Returns:
        The visual saliency tensor, :math:`(N, H, W)`.

    Example:
        >>> x = torch.rand(5, 3, 256, 256)
        >>> filtr = sdsp_filter(x)
        >>> vs = sdsp(x, filtr)
        >>> vs.size()
        torch.Size([5, 256, 256])
    """
    x_lab = xyz_to_lab(rgb_to_xyz(x, value_range))

    # Frequency prior
    x_f = fft.ifft2(fft.fft2(x_lab) * filtr)
    x_f = cx.real(torch.view_as_real(x_f))
    s_f = l2_norm(x_f, dims=[1])

    # Color prior
    x_ab = x_lab[:, 1:]

    lo, _ = x_ab.flatten(-2).min(dim=-1)
    up, _ = x_ab.flatten(-2).max(dim=-1)
    lo = lo.view(lo.shape + (1, 1))
    up = up.view(lo.shape)

    span = torch.where(up > lo, up - lo, torch.tensor(1.).to(lo))
    x_ab = (x_ab - lo) / span

    s_c = 1. - torch.exp(-torch.sum(x_ab**2, dim=1) / sigma_c**2)

    # Location prior
    a, b = [torch.arange(n).to(x) - (n - 1) / 2 for n in x.shape[-2:]]
    s_d = torch.exp(-(a[None, :]**2 + b[:, None]**2) / sigma_d**2)

    # Visual saliency
    vs = s_f * s_c * s_d

    return vs
def s2_fft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, complex]
    :return: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in
    batch_size = x.size()[:-3]

    x = x.view(-1, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, complex]

    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.view_as_real(torch.fft.fft(
        torch.view_as_complex(x)))  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec,
                                               nbatch=nbatch,
                                               device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[
                        x.contiguous().data_ptr(),
                        wigner.contiguous().data_ptr(),
                        output.data_ptr()
                    ],
                    stream=stream)  # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l**2, l**2 + 2 * l + 1)
            xx = torch.cat(
                (x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output
def A_realspace(r, t, psi, out):
    """
    :param r:   K x 2
    :param t:   BB x NY x NX
    :param psi: B x K x MY x MX
    :param out: K x MY x MX
    :return:
    """
    gpu = cuda.get_current_device()
    threadsperblock = gpu.MAX_THREADS_PER_BLOCK
    blockspergrid = m.ceil(np.prod(out.shape) / threadsperblock)
    A_realspace_kernel[blockspergrid, threadsperblock](
        r, th.view_as_real(t), th.view_as_real(psi), th.view_as_real(out))
    return out
def events(self) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    r"""Return the post-processed :math:`x^*` (with a leading batch dim) as a real view."""
    x = self.postprocess(self.x_star)
    x = torch.tensor(x)
    x = torch.view_as_real(x)
    return None, x[None]
def _single_tensor_adagrad(params: List[Tensor],
                           grads: List[Tensor],
                           state_sums: List[Tensor],
                           state_steps: List[Tensor],
                           *,
                           lr: float,
                           weight_decay: float,
                           lr_decay: float,
                           eps: float,
                           has_sparse_grad: bool):

    for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = step_t.item()

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values),
                       alpha=-clr)
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)
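# One dense-step sketch for _single_tensor_adagrad (hypothetical values):
# with grad 2.0 and an empty accumulator, state_sum becomes 4.0 and the
# parameter moves by -lr * 2.0 / sqrt(4.0) = -lr.
import torch

param, grad = torch.tensor([1.0]), torch.tensor([2.0])
state_sum, step = torch.zeros(1), torch.zeros(1)
_single_tensor_adagrad([param], [grad], [state_sum], [step], lr=0.1,
                       weight_decay=0.0, lr_decay=0.0, eps=1e-10,
                       has_sparse_grad=False)
# param ≈ 0.9, state_sum == 4.0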
def irfft(input, n=None):
    # Legacy-torch (< 1.8) shim: torch.irfft expects a real tensor with a
    # trailing size-2 complex dim, so pad real inputs with a zero imaginary part.
    if torch.is_complex(input):
        input = torch.view_as_real(input)
    else:
        input = torch.nn.functional.pad(input[..., None], (0, 1))
    if n is None:
        # input now has shape (..., freq, 2); recover the signal length
        # from the frequency dimension
        n = 2 * (input.size(-2) - 1)
    return torch.irfft(input, 1, signal_sizes=(n,))
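# Modern-torch sketch (>= 1.8): the legacy shim above corresponds to
# torch.fft.irfft on a complex half-spectrum.
import torch

spec = torch.randn(9, dtype=torch.complex64)  # rfft of a length-16 signal
sig = torch.fft.irfft(spec, n=2 * (spec.size(-1) - 1))
print(sig.shape)  # torch.Size([16])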
def _matrix_pow(matrix: torch.Tensor, p: float) -> torch.Tensor:
    # torch.eig returns eigenvalues as an (n, 2) real tensor (real, imag);
    # view it as complex to take the power, then keep only the real part.
    vals, vecs = torch.eig(matrix, eigenvectors=True)
    vals = torch.view_as_complex(vals.contiguous())
    vals_pow = vals.pow(p)
    vals_pow = torch.view_as_real(vals_pow)[:, 0]
    matrix_pow = torch.matmul(
        vecs, torch.matmul(torch.diag(vals_pow), torch.inverse(vecs)))
    return matrix_pow
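# Hedged modern variant: torch.eig is deprecated and removed in recent torch
# releases; torch.linalg.eig returns complex eigenpairs directly, so the
# view_as_complex/view_as_real round trip is unnecessary. The name
# _matrix_pow_linalg is hypothetical, and unlike the version above it keeps
# the full complex eigenvalues through the reconstruction, taking the real
# part only at the end.
import torch

def _matrix_pow_linalg(matrix: torch.Tensor, p: float) -> torch.Tensor:
    vals, vecs = torch.linalg.eig(matrix)
    return (vecs @ torch.diag(vals.pow(p)) @ torch.linalg.inv(vecs)).real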
def test_InverseSpectrogram_pseudocomplex(self):
    tensor = common_utils.get_whitenoise(sample_rate=8000)
    spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
    spectrogram = torch.view_as_real(spectrogram)
    self._assert_consistency(
        T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)
def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """

    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()
            size = grad.size()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values),
                       alpha=-clr)
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)