def convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """
    Computes the full 1-d linear convolution of ``signal`` with ``kernel``
    using FFTs.

    Both ``signal`` and ``kernel`` must be 1-dimensional.

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :return: torch.Tensor
        The full convolution, i.e. a tensor of size ``m + n - 1`` where
        ``m`` is the length of the signal and ``n`` is the length of the
        kernel.
    """
    both_1d = signal.ndim == 1 and kernel.ndim == 1
    assert both_1d, "signal and kernel must be 1-dimensional"

    # Length of the linear (non-circular) convolution.
    full_len = signal.size(-1) + kernel.size(-1) - 1
    # Zero-pad up to a length the FFT handles cheaply; the surplus samples
    # are sliced off again below.
    fft_len = next_fast_len(full_len)

    spectrum = rfft(signal, n=fft_len) * rfft(kernel, n=fft_len)
    return irfft(spectrum, n=fft_len)[:full_len]
def fft_convolve(signal, kernel):
    """Convolve ``signal`` with ``kernel`` along the last axis via FFT.

    The signal is zero-padded on the right by its own length and the kernel
    is zero-padded on the left by its own length. The two spectra are
    multiplied, transformed back, and only the second half of the result is
    returned.

    Both padded inputs must end up with the same last-axis length (or be
    broadcastable) for the spectral product to be defined.
    """
    signal_len = signal.shape[-1]
    kernel_len = kernel.shape[-1]

    padded_signal = nn.functional.pad(signal, (0, signal_len))
    padded_kernel = nn.functional.pad(kernel, (kernel_len, 0))

    spectrum = fft.rfft(padded_signal) * fft.rfft(padded_kernel)
    circular = fft.irfft(spectrum)

    # Keep the trailing half of the circular result.
    half = circular.shape[-1] // 2
    return circular[..., half:]
def forward(self, b, a):
    """Return the ratio of spectral magnitudes ``|rfft(b)| / |rfft(a)|``.

    Both inputs are transformed along dimension 1 with ``self.nfft`` points.
    ``self.eps`` is added to numerator and denominator magnitudes before
    dividing, and the result is asserted to be finite everywhere.
    """
    numerator = torch.abs(rfft(b, n=self.nfft, dim=1)) + self.eps
    denominator = torch.abs(rfft(a, n=self.nfft, dim=1)) + self.eps
    yabs = numerator / denominator
    assert torch.all(torch.isfinite(yabs))
    return yabs
def _process_batch(self, batch, target_has_phys=False): (c, f), target = batch if self.use_fft: f = rfft(f) predictions = self.model([c, f]) batch_size = c[-1, -1] + 1 predictions, target_tensor = self._format_target_and_prediction( predictions, c, target, batch_size, target_has_phys) if target_has_phys: if self.SE_only: loss = self.criterion.forward( self.SE_mask * predictions[:, 0, :, :], self.SE_mask * target_tensor[:, self.evaluator.z_index, :, :] ) * self.SE_factor else: loss = self.criterion.forward( predictions[:, 0, :, :], target_tensor[:, self.evaluator.z_index, :, :]) else: if self.SE_only: loss = self.criterion.forward( self.SE_mask * predictions, self.SE_mask * target_tensor) * self.SE_factor else: loss = self.criterion.forward(predictions, target_tensor) loss *= (self.evaluator.nx * self.evaluator.ny * batch_size / c.shape[0]) return loss, predictions, target_tensor, c, f
def __call__(self, arg): """ Apply filter on a (batched) signal. """ N = arg.shape[-1] if not N in self._cached: # cache or print or error if self.strict: print(f"caching sparse operator for size {N}") self.cache(N) # read cache d, w = self._cached[N] if d.dim() == arg.dim(): F_arg = rfft(w * arg) return irfft(d * F_arg) # batched F_arg = rfft(w * arg, dim=1) return irfft(d * F_arg, dim=1)
def init_fft(x, pdim=3):
    """Build an initial parameter stack from the strongest FFT bins of ``x``.

    The real FFT of ``x`` is sorted by decreasing magnitude. The returned
    tensor stacks three rows of length ``pdim``:

    * random amplitudes drawn as ``0.5 * randn``,
    * the phases of the ``pdim`` strongest bins,
    * the indices of those bins.
    """
    spectrum = rfft(x)
    a = 0.5 * torch.randn([pdim])
    # Second element of the sort result is the permutation of bin indices.
    order = torch.sort(spectrum.abs(), descending=True)[1]
    sorted_phase = torch.index_select(spectrum, 0, order).angle()
    return torch.stack([a, sorted_phase[:pdim], order[:pdim]])
def iirfreqz(h, ndft, squared=False, powerfloor=10**-3):
    """Compute frequency response of an IIR filter.

    ``h`` is zero-padded on the right to ``ndft`` samples along its last
    axis and transformed with a real FFT. The response is the reciprocal
    of the resulting spectrum's magnitude (or of its power when
    ``squared`` is true). The power is clamped from below at
    ``powerfloor`` before inversion so the reciprocal stays bounded.

    :param h: filter coefficients along the last axis.
    :param ndft: DFT length; must exceed ``h.size(-1)``.
    :param squared: return ``1 / |H|**2`` instead of ``1 / |H|``.
    :param powerfloor: lower bound applied to ``|H|**2``.
    :return: tensor with ``ndft // 2 + 1`` frequency bins on the last axis.
    """
    assert ndft > h.size(-1), "Incompatible DFT size!"
    h = F.pad(h, (0, ndft - h.size(-1)))
    # The complex-valued rfft returns the spectrum directly; the previous
    # ``rfft(h, 1)`` / ``[..., 0]`` / ``[..., 1]`` form followed the removed
    # real-pair API, where ``1`` meant the number of signal dimensions.
    hspec = fft.rfft(h)
    power = (hspec.real**2 + hspec.imag**2).clamp(min=powerfloor)
    if squared:
        return 1 / power
    return 1 / (power**.5)
def firfreqz(h, ndft, squared=False):
    """Compute frequency response of an FIR filter.

    ``h`` is zero-padded on the right to ``ndft`` samples along its last
    axis and transformed with a real FFT. The magnitude of the spectrum
    is returned, or its power when ``squared`` is true.

    :param h: filter coefficients along the last axis.
    :param ndft: DFT length; must exceed ``h.size(-1)``.
    :param squared: return ``|H|**2`` instead of ``|H|``.
    :return: tensor with ``ndft // 2 + 1`` frequency bins on the last axis.
    """
    assert ndft > h.size(-1), "Incompatible DFT size!"
    h = F.pad(h, (0, ndft - h.size(-1)))
    # The complex-valued rfft returns the spectrum directly; the previous
    # ``rfft(h, 1)`` / ``[..., 0]`` / ``[..., 1]`` form followed the removed
    # real-pair API, where ``1`` meant the number of signal dimensions.
    hspec = fft.rfft(h)
    power = hspec.real**2 + hspec.imag**2
    if squared:
        return power
    return power**.5
def conv1d(signal, kernel, mode='fft_circular', cut=False, cut_lim=150):
    """Convolve each row of ``signal`` with a single 1-d ``kernel``.

    signal  M x N
    kernel  N

    :param signal: 2-d tensor, one signal per row.
    :param kernel: 1-d kernel.
    :param mode: ``'direct'`` for a full linear convolution computed with
        ``F.conv1d`` (output length ``N + kernel_size - 1``), or
        ``'fft_circular'`` for a circular convolution computed with FFTs
        (output length ``N``).
    :param cut: when true, keep only columns
        ``cut_lim : kernel_size + cut_lim`` of the result.
    :param cut_lim: first column kept when ``cut`` is true.
    :raises ValueError: if ``mode`` is not one of the two supported values.
    """
    kernel_size = int(kernel.shape[-1])
    if mode == 'direct':
        # F.conv1d computes a cross-correlation, so the kernel is flipped
        # to obtain a true convolution.
        conved = F.conv1d(signal.unsqueeze(1),
                          kernel.flip(0).unsqueeze(0).unsqueeze(0),
                          padding=kernel_size - 1)[:, 0]
    elif mode == 'fft_circular':
        conved = irfft(rfft(signal) * rfft(kernel), signal.shape[-1])
    else:
        # Previously an unknown mode surfaced as an UnboundLocalError on
        # ``conved``; fail with an explicit message instead.
        raise ValueError(
            "Unknown mode: {!r} (expected 'direct' or 'fft_circular')".format(mode))
    if cut:
        conved = conved[:, cut_lim:kernel_size + cut_lim]
    return conved
def init(self, x, dev=0.01):
    """Initialise ``phi``, ``w`` and ``amp`` from the spectrum of ``x``.

    The real FFT bins of ``x`` are ranked by decreasing magnitude and the
    first ``self.n_modes`` of them are used:

    * ``self.phi`` - their phases,
    * ``self.w`` - their bin indices, cast to float,
    * ``self.amp`` - ``dev * |randn| * std(x)``.

    All three are stored as ``Parameter`` objects. Returns ``self``.
    """
    spectrum = rfft(x)
    order = torch.sort(spectrum.abs(), descending=True)[1]
    n = self.n_modes

    sorted_phase = torch.index_select(spectrum, 0, order).angle()
    phase = sorted_phase[:n]
    freq_index = order[:n]
    amplitude = dev * torch.randn([n]).abs() * torch.sqrt(torch.var(x))

    self.phi = Parameter(phase)
    self.w = Parameter(freq_index.float())
    self.amp = Parameter(amplitude)
    return self
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    The signal is extended to an even-symmetric sequence of length
    ``2 * (N - 1)`` (the original samples followed by the reversed interior
    samples) and the real part of its real FFT is returned.

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])
    mirrored = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    # The complex-valued rfft returns the spectrum directly, so the real
    # part is taken with ``.real``; the previous ``rfft(..., 1)[:, :, 0]``
    # form followed the removed real-pair API.
    return fft.rfft(mirrored, dim=1).real.reshape(*x_shape)
def __call__(self, image: torch.Tensor, context: ExpressionContext) -> torch.Tensor: std = 15. * torch.Tensor(context(self.std)).to(image.device).reshape( 3, 1) space = fft.rfft(image.reshape(3, -1)) space.real = space.real + torch.randn(space.shape).to( image.device) * std space.imag = space.imag + torch.randn(space.shape).to( image.device) * std return fft.irfft(space).reshape(*image.shape)
def convolve(signal, kernel, mode='full'):
    """
    Computes the 1-d convolution of signal by kernel using FFTs.

    The convolution runs over the rightmost dimension. The spectra of the
    two arguments are multiplied elementwise, so their remaining (leading)
    dimensions must be broadcastable.

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :param str mode: One of: 'full', 'valid', 'same'.
    :return: A tensor with broadcasted shape. Letting ``m = signal.size(-1)``
        and ``n = kernel.size(-1)``, the rightmost size of the result will be:
        ``m + n - 1`` if mode is 'full';
        ``max(m, n) - min(m, n) + 1`` if mode is 'valid'; or
        ``max(m, n)`` if mode is 'same'.
    :rtype torch.Tensor:
    :raises ValueError: if ``mode`` is not one of the three supported values.
    """
    m, n = signal.size(-1), kernel.size(-1)
    longer, shorter = max(m, n), min(m, n)

    if mode == 'full':
        keep = m + n - 1
    elif mode == 'valid':
        keep = longer - shorter + 1
    elif mode == 'same':
        keep = longer
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    # Linear convolution length, rounded up to a size the FFT handles
    # cheaply.
    linear_len = m + n - 1
    fft_len = next_fast_len(linear_len)

    product = rfft(signal, n=fft_len) * rfft(kernel, n=fft_len)
    full = irfft(product, n=fft_len)

    # Centre the requested window inside the full convolution.
    first = (linear_len - keep) // 2
    return full[..., first:first + keep]
def dct(x, dim=-1):
    """
    Discrete cosine transform of type II, scaled to be orthonormal.

    This is the inverse of :func:`idct_ii` , and is equivalent to
    :func:`scipy.fftpack.dct` with ``norm="ortho"``.

    The transform is computed with a single length-``N`` real FFT of a
    reordered copy of the input, followed by a multiplication with a
    complex twiddle factor.

    :param Tensor x: The input signal.
    :param int dim: Dimension along which to compute DCT.
    :rtype: Tensor
    """
    # Normalise ``dim`` to a negative index.
    if dim >= 0:
        dim -= x.dim()
    if dim != -1:
        # Collapse every axis after ``dim`` into one, swap it with ``dim``
        # so the target axis is last, recurse, then undo both steps.
        y = x.reshape(x.shape[:dim + 1] + (-1, )).transpose(-1, -2)
        return dct(y).transpose(-1, -2).reshape(x.shape)

    # Ref: http://fourier.eng.hmc.edu/e161/lectures/dct/node2.html
    N = x.size(-1)
    # Step 1: even-indexed samples followed by the odd-indexed samples in
    # reverse order.
    y = torch.cat([x[..., ::2], x[..., 1::2].flip(-1)], dim=-1)
    # Step 2: real FFT of the reordered signal (M = N // 2 + 1 bins).
    Y = rfft(y, n=N)
    # Step 3: multiply by the twiddle factor. ``coef_real[k]`` is
    # cos(pi * k / (2N)); reading the same table from the far end gives
    # sin(pi * k / (2N)), so ``coef`` holds (cos, -sin) pairs.
    coef_real = torch.cos(
        torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype,
                       device=x.device))
    M = Y.size(-1)
    coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1)
    # ``as_complex`` is defined outside this block; presumably it turns the
    # trailing (real, imag) pair into a complex tensor - TODO confirm.
    X = as_complex(coef) * Y
    # NB: if we use the full-length version Y_full = fft(y, n=N), then
    # the real part of the later half of X will be the flip
    # of the negative of the imaginary part of the first half
    X = torch.cat([X.real, -X.imag[..., 1:(N - M + 1)].flip(-1)], dim=-1)
    # orthogonalize: the first coefficient is divided by sqrt(N), the
    # remaining N - 1 by sqrt(N / 2).
    scale = torch.cat([
        x.new_tensor([math.sqrt(N)]),
        x.new_full((N - 1, ), math.sqrt(0.5 * N))
    ])
    return X / scale
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    The signal is centred, zero-padded, and its power spectrum is
    transformed back to obtain the autocovariance. Each lag is then
    divided by the number of sample pairs contributing to it and the
    result is normalised by the lag-0 value.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    :raises NotImplementedError: for a CPU tensor when MKL is unavailable.
    """
    if not (input.is_cuda or torch.backends.mkl.is_available()):
        raise NotImplementedError(
            "For CPU tensor, this method is only supported "
            "with MKL installed.")

    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    n_samples = input.size(dim)
    # Pad to twice a fast FFT length so circular wrap-around does not
    # contaminate the first ``n_samples`` lags.
    fft_len = 2 * next_fast_len(n_samples)

    # Work along the last axis.
    moved = input.transpose(dim, -1)
    centered = moved - moved.mean(dim=-1, keepdim=True)

    # Power spectrum: squared magnitude of every frequency bin.
    spectrum = torch.view_as_real(rfft(centered, n=fft_len))
    power = spectrum.pow(2).sum(-1)

    autocov = irfft(power, n=fft_len)[..., :n_samples]
    # Lag k is supported by (n_samples - k) pairs.
    pair_counts = torch.arange(
        n_samples, 0, -1, dtype=moved.dtype, device=moved.device)
    autocov = autocov / pair_counts
    autocorr = autocov / autocov[..., :1]
    return autocorr.transpose(dim, -1)
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    The signal is centred, zero-padded, and its power spectrum is
    transformed back to obtain the autocovariance. Each lag is then
    divided by the number of sample pairs contributing to it and the
    result is normalised by the lag-0 value.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied from `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    n_samples = input.size(dim)
    # Pad to twice a fast FFT length so circular wrap-around does not
    # contaminate the first ``n_samples`` lags.
    fft_len = 2 * next_fast_len(n_samples)

    # Work along the last axis.
    moved = input.transpose(dim, -1)
    centered = moved - moved.mean(dim=-1, keepdim=True)

    # Power spectrum: squared magnitude of every frequency bin.
    spectrum = torch.view_as_real(rfft(centered, n=fft_len))
    power = spectrum.pow(2).sum(-1)

    autocov = irfft(power, n=fft_len)[..., :n_samples]
    # Lag k is supported by (n_samples - k) pairs.
    pair_counts = torch.arange(
        n_samples, 0, -1, dtype=moved.dtype, device=moved.device)
    autocov = autocov / pair_counts
    autocorr = autocov / autocov[..., :1]
    return autocorr.transpose(dim, -1)
def train(model,
          train_loader,
          test_loader,
          mode='EDSR_Baseline',
          save_image_every=50,
          save_model_every=10,
          test_model_every=1,
          num_epochs=1000,
          device=None,
          refresh=True):
    """Train ``model`` with an L1 loss plus an FFT-domain L1 penalty.

    Per training step the loss is
    ``L1(hr, sr) + 0.1 * (L1(Re rfft(sr), Re rfft(hr)) + L1(Im ...))``.
    Optimiser is Adam (lr 1e-4) with a ``StepLR`` schedule stepped once
    per batch.

    Side effects: creates (and, when ``refresh`` is true, first removes)
    the date-stamped ``./results``, ``./weights`` and ``./logger``
    directories; writes checkpoints, sample images, TensorBoard scalars
    and a ``./hist_<date>_<mode>.csv`` history file.

    :param model: module whose call returns ``(sr, features)``.
    :param train_loader: yields ``(lr, hr, _)`` batches.
    :param test_loader: yields ``(lr, hr, fname)`` batches.
    :param mode: label used in directory names, log messages and the CSV.
    :param save_image_every: save a comparison image every this many steps.
    :param save_model_every: save a checkpoint every this many epochs.
    :param test_model_every: run the test loop every this many epochs.
    :param num_epochs: number of training epochs.
    :param device: target device; defaults to CUDA when available.
    :param refresh: delete previous outputs for the same date and mode.

    NOTE(review): the nesting of the statements that follow each inner
    loop was reconstructed from a whitespace-collapsed source - confirm
    against the original layout.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    csv = f'./hist_{today}_{mode}.csv'
    if refresh:
        # NOTE(review): the three removals share one ``try``; if the first
        # directory is missing the remaining two are not removed.
        try:
            shutil.rmtree(result_dir)
            shutil.rmtree(weight_dir)
            shutil.rmtree(logger_dir)
        except FileNotFoundError:
            pass
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(logger_dir, exist_ok=True)
    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)

    params = list(model.parameters())
    optim = torch.optim.Adam(params, lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                step_size=1000,
                                                gamma=0.99)
    criterion = torch.nn.L1Loss()

    start_time = time.time()
    print(f'Training Start || Mode: {mode}')

    # Global step counter across all epochs.
    step = 0
    # Progress-bar postfix dictionaries for the train and test loops.
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    # Per-epoch test history, written to CSV after every test pass.
    hist = dict()
    hist['mode'] = f'{today}_{mode}'
    for key in ['epoch', 'psnr', 'ssim', 'ms-ssim']:
        hist[key] = []

    for epoch in range(num_epochs):

        # Save the untrained weights once, before the first epoch.
        if epoch == 0:
            torch.save(model.state_dict(),
                       f'{weight_dir}/epoch_{epoch+1:04d}.pth')

        # Short warm-up evaluation before training: stops after two test
        # batches.
        if epoch == 0:
            with torch.no_grad():
                with tqdm(
                        test_loader,
                        desc=
                        f'Mode: {mode} || Warming Up || Test Epoch {epoch}/{num_epochs}',
                        position=0,
                        leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for lr, hr, fname in pbar_test:
                        lr = lr.to(device)
                        hr = hr.to(device)

                        sr, features = model(lr)
                        sr = quantize(sr)

                        psnr, ssim, msssim = evaluate(hr, sr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        psnr_mean = np.array(psnrs).mean()
                        ssim_mean = np.array(ssims).mean()
                        msssim_mean = np.array(msssims).mean()

                        pfix_test['psnr'] = f'{psnr:.4f}'
                        pfix_test['ssim'] = f'{ssim:.4f}'
                        pfix_test['msssim'] = f'{msssim:.4f}'
                        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
                        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
                        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'

                        pbar_test.set_postfix(pfix_test)
                        if len(psnrs) > 1:
                            break

        with tqdm(train_loader,
                  desc=f'Mode: {mode} || Epoch {epoch+1}/{num_epochs}',
                  position=0,
                  leave=True) as pbar:
            psnrs = []
            ssims = []
            msssims = []
            losses = []
            for lr, hr, _ in pbar:
                lr = lr.to(device)
                hr = hr.to(device)

                # prediction
                sr, features = model(lr)

                # Frequency-domain penalty: L1 on the real and imaginary
                # parts of the last-axis real FFT.
                srfft = fft.rfft(sr)
                hrfft = fft.rfft(hr)

                loss_fft = criterion(srfft.real, hrfft.real)
                loss_fft += criterion(srfft.imag, hrfft.imag)

                # training
                loss = criterion(hr, sr)
                loss_tot = loss + 0.1 * loss_fft
                optim.zero_grad()
                loss_tot.backward()
                optim.step()
                # The learning-rate schedule advances per batch, not per
                # epoch.
                scheduler.step()

                # training history
                elapsed_time = time.time() - start_time
                elapsed = sec2time(elapsed_time)
                pfix['Step'] = f'{step+1}'
                pfix['Loss'] = f'{loss.item():.4f}'
                pfix['Loss FFT'] = f'{loss_fft.item():.4f}'

                sr = quantize(sr)
                psnr, ssim, msssim = evaluate(hr, sr)

                psnrs.append(psnr)
                ssims.append(ssim)
                msssims.append(msssim)

                psnr_mean = np.array(psnrs).mean()
                ssim_mean = np.array(ssims).mean()
                msssim_mean = np.array(msssims).mean()

                pfix['PSNR'] = f'{psnr:.2f}'
                pfix['SSIM'] = f'{ssim:.4f}'
                # pfix['MSSSIM'] = f'{msssim:.4f}'
                pfix['PSNR_mean'] = f'{psnr_mean:.2f}'
                pfix['SSIM_mean'] = f'{ssim_mean:.4f}'
                # pfix['MSSSIM_mean'] = f'{msssim_mean:.4f}'

                free_gpu = get_gpu_memory()[0]

                pfix['free GPU'] = f'{free_gpu}MiB'
                pfix['Elapsed'] = f'{elapsed}'

                pbar.set_postfix(pfix)
                # Only the plain L1 term is recorded, not ``loss_tot``.
                losses.append(loss.item())

                if step % save_image_every == 0:
                    # Pad the low-resolution image with zero blocks along
                    # axis -2 so it can be shown next to ``sr`` and ``hr``.
                    z = torch.zeros_like(lr[0])
                    _, _, llr, _ = lr.shape
                    _, _, hlr, _ = hr.shape
                    # NOTE(review): ``xz`` stays undefined if the scale is
                    # neither 2 nor 4.
                    if hlr // 2 == llr:
                        xz = torch.cat((lr[0], z), dim=-2)
                    elif hlr // 4 == llr:
                        xz = torch.cat((lr[0], z, z, z), dim=-2)
                    imsave([xz, sr[0], hr[0]],
                           f'{result_dir}/epoch_{epoch+1}_iter_{step:05d}.jpg')

                step += 1

        logger.add_scalar("Loss/train", np.array(losses).mean(), epoch + 1)
        logger.add_scalar("PSNR/train", psnr_mean, epoch + 1)
        logger.add_scalar("SSIM/train", ssim_mean, epoch + 1)

        if (epoch + 1) % save_model_every == 0:
            torch.save(model.state_dict(),
                       f'{weight_dir}/epoch_{epoch+1:04d}.pth')

        if (epoch + 1) % test_model_every == 0:

            with torch.no_grad():
                with tqdm(
                        test_loader,
                        desc=
                        f'Mode: {mode} || Test Epoch {epoch+1}/{num_epochs}',
                        position=0,
                        leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for lr, hr, fname in pbar_test:

                        # Base name of the first file in the batch, with
                        # directory and everything from '.pt' removed.
                        fname = fname[0].split('/')[-1].split('.pt')[0]

                        lr = lr.to(device)
                        hr = hr.to(device)

                        sr, features = model(lr)
                        sr = quantize(sr)

                        psnr, ssim, msssim = evaluate(hr, sr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        psnr_mean = np.array(psnrs).mean()
                        ssim_mean = np.array(ssims).mean()
                        msssim_mean = np.array(msssims).mean()

                        pfix_test['psnr'] = f'{psnr:.4f}'
                        pfix_test['ssim'] = f'{ssim:.4f}'
                        pfix_test['msssim'] = f'{msssim:.4f}'
                        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
                        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
                        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'

                        pbar_test.set_postfix(pfix_test)

                        z = torch.zeros_like(lr[0])
                        _, _, llr, _ = lr.shape
                        _, _, hlr, _ = hr.shape
                        if hlr // 2 == llr:
                            xz = torch.cat((lr[0], z), dim=-2)
                        elif hlr // 4 == llr:
                            xz = torch.cat((lr[0], z, z, z), dim=-2)
                        imsave([xz, sr[0], hr[0]],
                               f'{result_dir}/{fname}.jpg')

                    # Record the running means after the full test pass.
                    hist['epoch'].append(epoch + 1)
                    hist['psnr'].append(psnr_mean)
                    hist['ssim'].append(ssim_mean)
                    hist['ms-ssim'].append(msssim_mean)

                    logger.add_scalar("PSNR/test", psnr_mean, epoch + 1)
                    logger.add_scalar("SSIM/test", ssim_mean, epoch + 1)
                    logger.add_scalar("MS-SSIM/test", msssim_mean, epoch + 1)

                    # The CSV is rewritten in full after every test pass.
                    df = pd.DataFrame(hist)
                    df.to_csv(csv)
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int],
                dilation: Tuple[int], groups: int) -> Tensor:
    """N-dimensional convolution computed in the frequency domain.

    The input is zero-padded, transformed with ``rfftn`` and multiplied
    with the conjugated spectrum of the weight through ``_complex_matmul``.
    Dilation and stride are applied by manipulating the spectra rather
    than the signals.

    :param input: tensor whose spatial axes start at index 2.
    :param weight: kernel whose spatial axes start at index 2.
    :param bias: optional bias, broadcast over the spatial axes.
    :param stride: per-axis stride.
    :param padding: per-axis zero padding applied to both sides.
    :param dilation: per-axis dilation.
    :param groups: forwarded to ``_complex_matmul``.
    :return: output cropped to the size given by ``_conv_shape``.
    """
    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride,
                              padding, dilation)
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    # ``s`` holds the FFT length per spatial axis for the input and
    # ``weight_s`` the corresponding (dilation-reduced) length for the
    # weight.
    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        # (the last axis additionally has to be even because it goes
        # through the real FFT)
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        weight_s.append(s_size // d)

    X = rfftn(padded_input, s=s)

    W = rfft(weight, n=weight_s[-1])
    # handle dilation
    # handle dilation for last dim: the half spectrum is tiled
    # ``dilation[-1]`` times, alternating between the conjugated mirrored
    # copy and the original (minus its first bin).
    if dilation[-1] > 1:
        # ``flip`` returns a copy, so the in-place conjugation below does
        # not touch ``W``.
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        # Full complex FFT over the remaining spatial axes.
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        # Conjugate the weight spectrum in place.
        W.imag.mul_(-1)
        # Tile the spectrum along the dilated axes (skipped when every
        # repeat count is 1).
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    Y = _complex_matmul(X, W, groups)

    # handle stride
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                # Split the frequency axis into ``st`` equal chunks and
                # average them.
                Y = Y.reshape(*Y.shape[:i + 2], st, -1,
                              *Y.shape[i + 3:]).mean(i + 2)

            # Every non-last spatial axis is inverted here, strided or
            # not; the final ``irfft`` only inverts the last axis.
            Y = ifft(Y, dim=i + 2)
            # Shrink the axis to its output size while keeping the
            # existing strides.
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:],
                Y.stride())

    if stride[-1] > 1:
        # Stride on the last axis operates on the half spectrum: windows
        # are taken with ``unfold``; even-indexed windows are summed
        # as-is, odd-indexed ones are summed and flipped, and the
        # imaginary part of the flipped sum is subtracted.
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2,
                                                 step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(
            -2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(
            -1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        # The first and last bins get a zero imaginary part.
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag),
                                              -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
def _fft_conv_transposend(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    stride: Tuple[int],
    padding: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    dilation: Tuple[int],
) -> Tensor:
    """N-dimensional transposed convolution computed in the frequency domain.

    Input and weight are transformed, their spectra are tiled to account
    for stride and dilation respectively, multiplied through
    ``_complex_matmul`` and transformed back with ``irfftn``.

    :param input: tensor whose spatial axes start at index 2.
    :param weight: kernel whose spatial axes start at index 2.
    :param bias: optional bias, broadcast over the spatial axes.
    :param stride: per-axis stride.
    :param padding: per-axis padding; cropped from both sides of the result.
    :param output_padding: forwarded to ``_conv_transpose_shape``.
    :param groups: forwarded to ``_complex_matmul``.
    :param dilation: per-axis dilation.
    :return: output cropped to the size given by ``_conv_transpose_shape``.
    """
    output_size = _conv_transpose_shape(input.shape[2:], weight.shape[2:],
                                        stride, padding, output_padding,
                                        dilation)
    # Size of the result before the padding is cropped away.
    padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))

    # ``s`` holds the (stride-reduced) FFT length per spatial axis for the
    # input and ``weight_s`` the (dilation-reduced) length for the weight.
    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_output_size, weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        # (the last axis additionally has to be even because it goes
        # through the real FFT)
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size // st)
        weight_s.append(s_size // d)

    X = rfft(input, n=s[-1])
    W = rfft(weight, n=weight_s[-1])

    # Stride on the last axis: tile the input half spectrum
    # ``stride[-1]`` times, alternating between the conjugated mirrored
    # copy and the original (minus its first bin).
    if stride[-1] > 1:
        # ``flip`` returns a copy, so the in-place conjugation below does
        # not touch ``X``.
        X_neg_freq = X.flip(-1)[..., 1:]
        X_neg_freq.imag.mul_(-1)

        tmp = [X]
        for i in range(1, stride[-1]):
            if i % 2:
                tmp.append(X_neg_freq)
            else:
                tmp.append(X[..., 1:])

        X = torch.cat(tmp, -1)

    # Dilation on the last axis: the same tiling applied to the weight
    # spectrum.
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(s) > 1:
        # Full complex FFT over the remaining spatial axes, followed by
        # tiling along the strided / dilated axes (skipped when every
        # repeat count is 1).
        X = fftn(X, s=s[:-1], dim=tuple(range(2, X.ndim - 1)))
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + stride[:-1] + (1, )
        if sum(repeats) > X.ndim:
            X = X.repeat(*repeats)

        repeats = (1, 1) + dilation[:-1] + (1, )
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)

    Y = _complex_matmul(X, W, groups, True)

    output = irfftn(Y, dim=tuple(range(2, Y.ndim)))

    # Remove extra padded values
    index = (slice(None), ) * 2 + tuple(
        slice(p, o + p) for p, o in zip(padding, output_size))
    output = output[index].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
def _new_rfft(x: torch.Tensor): z = new_fft.rfft(x, dim=-1) return torch.view_as_real(z)
def convolve1d(
    waveform,
    kernel,
    padding=0,
    pad_type="constant",
    stride=1,
    groups=1,
    use_fft=False,
    rotation_index=0,
):
    """Pad and convolve a 3-d waveform with a kernel along the time axis.

    Both tensors carry time on axis 1; they are transposed so that time is
    last for the computation and the result is transposed back.

    Arguments
    ---------
    waveform : tensor
        The tensor to perform operations on. Must be 3-dimensional.
    kernel : tensor
        The filter to apply during convolution.
    padding : int or tuple
        A tuple ``(pad_left, pad_right)`` is applied explicitly with
        `torch.nn.functional.pad` using `pad_type`. An integer is handed
        to `conv1d` instead, in which case `pad_type` is ignored.
    pad_type : str
        Padding mode passed to `torch.nn.functional.pad`, see PyTorch
        documentation for available options.
    stride : int
        Step between successive applications of the kernel. Passed to
        `conv1d`; has no effect if `use_fft` is True.
    groups : int
        Passed to `conv1d` to split the input into groups. Input channels
        should be divisible by the number of groups.
    use_fft : bool
        Compute the convolution as a product of spectra. This is more
        efficient on CPU when the kernel is large (e.g. reverberation).
        WARNING: without padding the result is a circular convolution.
    rotation_index : int
        Only used if `use_fft` is True. The kernel is rolled by this
        amount before convolution to shift the output location.

    Returns
    -------
    The convolved waveform.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> signal = read_audio('samples/audio_samples/example1.wav')
    >>> signal = signal.unsqueeze(0).unsqueeze(2)
    >>> kernel = torch.rand(1, 10, 1)
    >>> signal = convolve1d(signal, kernel, padding=(9, 0))
    """
    if waveform.ndim != 3:
        raise ValueError("Convolve1D expects a 3-dimensional tensor")

    # pad, fft and conv all expect time on the last axis.
    waveform = waveform.transpose(2, 1)
    kernel = kernel.transpose(2, 1)

    # Tuple padding is applied here; integer padding is left to conv1d.
    explicit_padding = isinstance(padding, tuple)
    if explicit_padding:
        waveform = torch.nn.functional.pad(
            input=waveform, pad=padding, mode=pad_type,
        )

    if not use_fft:
        # Torch's own implementation, which should be efficient on GPU.
        convolved = torch.nn.functional.conv1d(
            input=waveform,
            weight=kernel,
            stride=stride,
            groups=groups,
            padding=0 if explicit_padding else padding,
        )
        return convolved.transpose(2, 1)

    # FFT path: bring the kernel to the signal length first.
    n_samples = waveform.size(-1)
    zero_length = n_samples - kernel.size(-1)

    # A kernel longer than the signal is truncated to the signal length.
    if zero_length < 0:
        kernel = kernel[..., :zero_length]
        zero_length = 0

    # Roll the kernel: the part from ``rotation_index`` onwards goes first,
    # the zero filler in the middle and the leading part at the very end.
    filler = torch.zeros(
        kernel.size(0), kernel.size(1), zero_length, device=kernel.device
    )
    tail = kernel[..., rotation_index:]
    head = kernel[..., :rotation_index]
    kernel = torch.cat((tail, filler, head), dim=-1)

    # Multiply in frequency domain to convolve in time domain.
    if version.parse(torch.__version__) > version.parse("1.6.0"):
        import torch.fft as fft

        spectrum = fft.rfft(waveform) * fft.rfft(kernel)
        convolved = fft.irfft(spectrum, n=n_samples)
    else:
        # Older torch returns (real, imag) pairs on a trailing axis, so
        # the complex product is written out by hand.
        sig_real, sig_imag = torch.rfft(waveform, 1).unbind(-1)
        ker_real, ker_imag = torch.rfft(kernel, 1).unbind(-1)
        product = torch.stack(
            [
                sig_real * ker_real - sig_imag * ker_imag,
                sig_real * ker_imag + sig_imag * ker_real,
            ],
            dim=-1,
        )
        convolved = torch.irfft(product, 1, signal_sizes=[n_samples])

    # Return time dimension to the second dimension.
    return convolved.transpose(2, 1)
def run(self):
    """Run every corpus / embedding / model-type combination.

    For each combination the benchmark corpus and the pre-trained
    embeddings are prepared, a ``Model`` is built, a single forward pass
    is executed on the first training document, and four Fourier
    transform plots of that document's embedded tokens are shown.

    Returns ``True``.

    NOTE(review): this looks like exploratory code - it prints
    intermediate values and stops at a ``breakpoint()`` on every
    iteration. The nesting of the statements was reconstructed from a
    whitespace-collapsed source; confirm against the original layout.
    """
    # TODO: Set up a way to consolidate and produce a final report of all experiments.
    for corpus in self.corpus_type:
        for embedding in self.embedding_type:
            benchmark = GetBenchmarkCorpra(corpus_type=corpus,
                                           parent=self.parent)
            benchmark.run()
            corpra_object = benchmark.corpra

            pre_trained_embedding = GetPreTrainedEmbeddingsStage(
                embedding_type=embedding, parent=self.parent)
            pre_trained_embedding.run()

            data = Benchmark2Embeddings(embedding_type=embedding,
                                        corpra_object=corpra_object,
                                        min_freq=self.min_freq,
                                        parent=self.parent,
                                        corpus_type=corpus)
            data.run()

            # NOTE(review): the labels, ``valid_file`` and ``test_file``
            # are assigned but not used further down in this method.
            train_file = data.corpra_numeric['train']
            train_labels = data.corpra_labels['train']
            valid_file = None
            test_file = data.corpra_numeric['test']
            test_labels = data.corpra_labels['test']
            vectors = data.vocab.vectors
            dictionary = data.vocab.stoi

            for model_type in self.model_type:
                # NOTE(review): ``self.parent`` is overwritten here, so
                # later iterations of the outer loops pass this
                # combination label as ``parent``.
                self.parent = f"{corpus}_{embedding}_{model_type}"
                self.logger.info("=" * 40)
                self.logger.info(
                    f'Running experiments for Corpus: {corpus}'
                    f', with Embedding: {embedding}'
                    f', and Model Type: {model_type}.')
                self.logger.info("=" * 40)
                self.model_config.update({'model_type': model_type})
                model = Model(dictionary_size=len(dictionary),
                              embedding_vectors=vectors,
                              embedding_size=vectors.size()[1],
                              model_task=self.model_task,
                              input_size=1,
                              **self.model_config)

                # run one cycle to check backprop of ft
                # (only a forward pass happens here; no loss or backward
                # call follows)
                model.train()
                model.zero_grad()
                X = torch.tensor(train_file[0])
                states = generate_initial_states(model)
                # Reshape to a single column.
                X = X.reshape(X.size(0), 1)
                output, states = model(X, states)

                print('output', output)
                print('states', states)
                # NOTE(review): debugger stop left in; execution halts
                # here on every iteration.
                breakpoint()

                token_sample_nums = data.corpra_numeric['train'][0][1]
                token_sample_txts = [
                    data.vocab.itos[i] for i in token_sample_nums
                ][0]
                token_sample_vectors = [
                    data.vocab.vectors[i] for i in token_sample_nums
                ]

                print(token_sample_nums[:20])
                print(token_sample_txts[:20])
                print(len(token_sample_nums))
                print(len(token_sample_vectors))
                print(token_sample_vectors[0])

                fig, ax = plt.subplots(2, 2)

                # This could be in the model class where we define embedding
                self.embedding = prep_embedding_layer(
                    vectors=data.vocab.vectors, trainable=False)
                X = self.embedding(token_sample_nums)

                # Top-left: line plot of column 1 of the n-dimensional
                # real FFT.
                ft1 = ft.rfftn(input=X, norm="forward")
                ax[0, 0].plot(ft1[:, 1])
                ax[0, 0].set_title(
                    f'rfftn forward - regular 1-dim\n- ft shape({ft1.shape})',
                    size=8)

                # similar to ft2, it's nothing but an ft on a list of numbers
                # Top-right: line plot of the last-axis real FFT.
                ft2 = ft.rfft(input=X, norm="forward")
                ax[0, 1].plot(ft2)
                ax[0, 1].set_title(
                    f'rfft forward - regular all-dim\n- ft shape({ft2.shape})',
                    size=8)

                # testing forward: this could be after embedding layer?
                # Bottom-left: spectrogram of column 1 of the
                # n-dimensional real FFT.
                ft3 = ft.rfftn(input=X, norm="forward")
                ax[1, 0].specgram(ft3[:, 1])
                ax[1, 0].set_title(
                    f'rfftn forward - specgram 1-dim\n- ft shape({ft3.shape})',
                    size=8)

                # testing backward
                # NOTE(review): the call below still uses norm="forward".
                ft4 = ft.rfft(input=X, norm="forward")
                ax[1, 1].specgram(ft4.T)
                # ax[1, 1].specgram(ft4[:, 1], alpha=.5)
                ax[1, 1].set_title(
                    f'rfft forward - specgram all-dim\n- ft shape({ft4.shape})',
                    size=8)

                # display figures
                fig.suptitle(
                    f'Plots of Fourier Transforms on a Single IMDB Document\n'
                    f'With Glove 50d Embedding and Doc Len: {len(token_sample_nums)}.'
                )
                fig.tight_layout()
                plt.show()

    return True
import torch from torch.fft import rfft, irfft from math import pi fs = 100 # 1 Hz sampled at 100Hz for 12s t = torch.linspace(0, 12 * 2 * pi, 1200) x1 = torch.sin(10 * t) x2 = torch.sin(20 * t) sp = torch.linspace(0, 50, 601) bp1 = bandpass(5, 15, 100, N=601) bp2 = bandpass(18, 22, 100, N=601) Fx1 = rfft(x1) Fx2 = rfft(x2) Fx1 /= rfft(x1).abs().max() Fx2 /= rfft(x2).abs().max() class TestBandpass(test.TestCase): def test_bandpass(self): result = bp1(x1 + x2)[100:-100] expect = x1[100:-100] self.assertClose(expect, result, tol=1e-1) result = bp2(x1 + x2)[100:-100] expect = x2[100:-100] self.assertClose(expect, result, tol=1e-1) def test_bandpass_resampled(self):