Example #1
def convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """
    Computes the 1-d convolution of signal by kernel using FFTs.
    Both signal and kernel must be 1-dimensional.
    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :return: torch.Tensor Convolution of signal with kernel. Returns the full
        convolution, i.e., the output tensor will have size m + n - 1, where m
        is the length of the signal and n is the length of the kernel.
    """
    assert (signal.ndim == 1
            and kernel.ndim == 1), "signal and kernel must be 1-dimensional"
    m = signal.size(-1)
    n = kernel.size(-1)

    # Compute convolution using fft.
    padded_size = m + n - 1
    # Round up for cheaper fft.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    return result[:padded_size]
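
Usage sketch (not from the source), cross-checking against numpy.convolve; it assumes rfft/irfft come from torch.fft and next_fast_len from scipy.fft:

import numpy as np
import torch

signal = torch.randn(100)
kernel = torch.randn(7)
out = convolve1d(signal, kernel)       # full convolution, length 100 + 7 - 1
ref = np.convolve(signal.numpy(), kernel.numpy(), mode="full")
assert out.shape == (106,)
assert np.allclose(out.numpy(), ref, atol=1e-3)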
Example #2
def fft_convolve(signal, kernel):
    # Both tensors are assumed to share the same last-dim length N; the right-pad
    # on the signal and the N-sample delay on the kernel keep the second half of
    # the circular convolution free of wrap-around.
    signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))

    output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
    # Keep the trailing half, which has the input's length.
    output = output[..., output.shape[-1] // 2:]

    return output
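
A shape-check sketch; the function assumes signal and kernel share the same last-dim length N (otherwise the rfft outputs would not broadcast), and the kept half has exactly that length:

import torch

x = torch.randn(4, 1024)   # batch of signals
h = torch.randn(4, 1024)   # impulse responses of the same length
y = fft_convolve(x, h)
assert y.shape == x.shape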
Example #3
 def forward(self, b, a):
     b_fft = rfft(b, n=self.nfft, dim=1)
     a_fft = rfft(a, n=self.nfft, dim=1)
     # Magnitude response |B| / |A|, with eps guarding against division by zero.
     yabs = (torch.abs(b_fft) + self.eps) / (torch.abs(a_fft) + self.eps)
     assert torch.all(torch.isfinite(yabs))
     return yabs
Example #4
 def _process_batch(self, batch, target_has_phys=False):
     (c, f), target = batch
     if self.use_fft:
         f = rfft(f)
     predictions = self.model([c, f])
     batch_size = c[-1, -1] + 1
     predictions, target_tensor = self._format_target_and_prediction(
         predictions, c, target, batch_size, target_has_phys)
     if target_has_phys:
         if self.SE_only:
             loss = self.criterion.forward(
                 self.SE_mask * predictions[:, 0, :, :], self.SE_mask *
                 target_tensor[:, self.evaluator.z_index, :, :]
             ) * self.SE_factor
         else:
             loss = self.criterion.forward(
                 predictions[:, 0, :, :],
                 target_tensor[:, self.evaluator.z_index, :, :])
     else:
         if self.SE_only:
             loss = self.criterion.forward(
                 self.SE_mask * predictions,
                 self.SE_mask * target_tensor) * self.SE_factor
         else:
             loss = self.criterion.forward(predictions, target_tensor)
     loss *= (self.evaluator.nx * self.evaluator.ny * batch_size /
              c.shape[0])
     return loss, predictions, target_tensor, c, f
Example #5
 def __call__(self, arg):
     """ Apply filter on a (batched) signal. """
     N = arg.shape[-1]
     if N not in self._cached:
         # build and cache the sparse operator for this signal length
         if self.strict:
             print(f"caching sparse operator for size {N}")
         self.cache(N)
     # read cache
     d, w = self._cached[N]
     if d.dim() == arg.dim():  # unbatched signal
         F_arg = rfft(w * arg)
         return irfft(d * F_arg)
     # batched
     F_arg = rfft(w * arg, dim=1)
     return irfft(d * F_arg, dim=1)
Example #6
def init_fft(x, pdim=3):
    Fx = rfft(x)
    a = 0.5 * torch.randn([pdim])
    _, js = torch.sort(Fx.abs(), descending=True)
    phi = torch.index_select(Fx, 0, js).angle()[:pdim]
    w = js[:pdim].float()  # cast bin indices so the stack has a uniform dtype
    return torch.stack([a, phi, w])
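
A toy run (sketch; assumes rfft is torch.fft.rfft): the rows of the result are the amplitudes, phases, and rfft bin indices of the pdim strongest modes.

import math
import torch

t = torch.linspace(0, 2 * math.pi, 256)
x = torch.sin(5 * t) + 0.5 * torch.sin(12 * t)
params = init_fft(x, pdim=3)   # shape (3, 3): rows are a, phi, w
print(params[2])               # dominant bin indices, roughly 5 and 12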
Example #7
def iirfreqz(h, ndft, squared=False, powerfloor=10**-3):
    """Compute frequency response of an IIR filter."""
    assert ndft > h.size(-1), "Incompatible DFT size!"
    h = F.pad(h, (0, ndft - h.size(-1)))
    # torch.fft.rfft returns a complex tensor (the old torch.rfft(h, 1)
    # stacked real/imag pairs instead)
    hspec = fft.rfft(h)
    hspec = (hspec.real**2 + hspec.imag**2).clamp(min=powerfloor)
    if squared:
        return 1 / hspec
    return 1 / (hspec**.5)
Example #8
def firfreqz(h, ndft, squared=False):
    """Compute frequency response of an FIR filter."""
    assert ndft > h.size(-1), "Incompatible DFT size!"
    h = F.pad(h, (0, ndft - h.size(-1)))
    hspec = fft.rfft(h)  # complex output under the torch.fft API
    hspec = hspec.real**2 + hspec.imag**2
    if squared:
        return hspec
    return hspec**.5
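
A sanity-check sketch (the snippet's F and fft are assumed to be torch.nn.functional and torch.fft): a length-3 moving average has unit DC gain.

import torch

h = torch.full((3,), 1 / 3)    # moving-average FIR taps
mag = firfreqz(h, ndft=64)     # |H| at the 33 rfft bins
assert torch.isclose(mag[0], torch.tensor(1.0))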
Example #9
def conv1d(signal, kernel, mode='fft_circular', cut=False, cut_lim=150):
    """
    signal M x N
    kernel N
    """
    kernel_size = int(kernel.shape[-1])

    if mode == 'direct':
        conved = F.conv1d(signal.unsqueeze(1),
                          kernel.flip(0).unsqueeze(0).unsqueeze(0),
                          padding=kernel_size - 1)[:, 0]

    elif mode == 'fft_circular':
        # same-length inputs: a period-N circular convolution
        conved = irfft(rfft(signal) * rfft(kernel), signal.shape[-1])

    else:
        raise ValueError(f'Unknown mode: {mode}')

    if cut:
        conved = conved[:, cut_lim:kernel_size + cut_lim]

    return conved
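
Usage sketch; per the docstring the kernel matches the signal's last-dim length N, so 'fft_circular' is a period-N circular convolution:

import torch

sig = torch.randn(2, 300)   # M x N
ker = torch.randn(300)      # length N
out = conv1d(sig, ker)      # default mode='fft_circular'
assert out.shape == sig.shape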
Example #10
 def init(self, x, dev=0.01):
     Fx = rfft(x)
     _, js = torch.sort(Fx.abs(), descending=True)
     n = self.n_modes
     phi = torch.index_select(Fx, 0, js).angle()[:n]
     w = js[:n]
     amp = dev * torch.randn([n]).abs() * torch.sqrt(torch.var(x))
     self.phi = Parameter(phi)
     self.w = Parameter(w.float())
     self.amp = Parameter(amp)
     return self
Example #11
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])

    # The rfft of the even extension [x, reversed interior of x] (length 2N-2)
    # is real-valued and equals the unnormalized DCT-I; the old code indexed
    # the real part via torch.rfft's trailing (real, imag) dim.
    return fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]],
                              dim=1)).real.view(*x_shape)
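
Cross-check sketch against SciPy's unnormalized DCT-I (assumes fft is torch.fft):

import torch
from scipy.fft import dct as scipy_dct

x = torch.randn(5, 16)
ref = torch.from_numpy(scipy_dct(x.numpy(), type=1)).float()
assert torch.allclose(dct1(x), ref, atol=1e-3)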
Example #12
    def __call__(self, image: torch.Tensor,
                 context: ExpressionContext) -> torch.Tensor:
        std = 15. * torch.Tensor(context(self.std)).to(image.device).reshape(
            3, 1)

        flat = image.reshape(3, -1)
        space = fft.rfft(flat)
        space.real = space.real + torch.randn(space.shape).to(
            image.device) * std
        space.imag = space.imag + torch.randn(space.shape).to(
            image.device) * std

        # pass n explicitly: irfft's default length 2*(bins - 1) drops a sample
        # when the flattened image has an odd number of pixels
        return fft.irfft(space, n=flat.shape[-1]).reshape(*image.shape)
Example #13
def convolve(signal, kernel, mode='full'):
    """
    Computes the 1-d convolution of signal by kernel using FFTs.
    The two arguments should have the same rightmost dim, but may otherwise be
    arbitrarily broadcastable.

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :param str mode: One of: 'full', 'valid', 'same'.
    :return: A tensor with broadcasted shape. Letting ``m = signal.size(-1)``
        and ``n = kernel.size(-1)``, the rightmost size of the result will be:
        ``m + n - 1`` if mode is 'full';
        ``max(m, n) - min(m, n) + 1`` if mode is 'valid'; or
        ``max(m, n)`` if mode is 'same'.
    :rtype torch.Tensor:
    """
    m = signal.size(-1)
    n = kernel.size(-1)
    if mode == 'full':
        truncate = m + n - 1
    elif mode == 'valid':
        truncate = max(m, n) - min(m, n) + 1
    elif mode == 'same':
        truncate = max(m, n)
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    # Compute convolution using fft.
    padded_size = m + n - 1
    # Round up for cheaper fft.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    start_idx = (padded_size - truncate) // 2
    return result[..., start_idx:start_idx + truncate]
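
The three modes mirror numpy.convolve; a quick cross-check sketch (same assumed imports as in Example #1):

import numpy as np
import torch

s = torch.randn(20)
k = torch.randn(5)
for mode in ("full", "valid", "same"):
    ours = convolve(s, k, mode=mode)
    ref = np.convolve(s.numpy(), k.numpy(), mode=mode)
    assert np.allclose(ours.numpy(), ref, atol=1e-3)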
Example #14
def dct(x, dim=-1):
    """
    Discrete cosine transform of type II, scaled to be orthonormal.

    This is the inverse of :func:`idct_ii` , and is equivalent to
    :func:`scipy.fftpack.dct` with ``norm="ortho"``.

    :param Tensor x: The input signal.
    :param int dim: Dimension along which to compute DCT.
    :rtype: Tensor
    """
    if dim >= 0:
        dim -= x.dim()
    if dim != -1:
        y = x.reshape(x.shape[:dim + 1] + (-1, )).transpose(-1, -2)
        return dct(y).transpose(-1, -2).reshape(x.shape)

    # Ref: http://fourier.eng.hmc.edu/e161/lectures/dct/node2.html
    N = x.size(-1)
    # Step 1
    y = torch.cat([x[..., ::2], x[..., 1::2].flip(-1)], dim=-1)
    # Step 2
    Y = rfft(y, n=N)
    # Step 3
    coef_real = torch.cos(
        torch.linspace(0, 0.5 * math.pi, N + 1, dtype=x.dtype,
                       device=x.device))
    M = Y.size(-1)
    coef = torch.stack([coef_real[:M], -coef_real[-M:].flip(-1)], dim=-1)
    X = as_complex(coef) * Y
    # NB: if we use the full-length version Y_full = fft(y, n=N), then
    # the real part of the later half of X will be the flip
    # of the negative of the imaginary part of the first half
    X = torch.cat([X.real, -X.imag[..., 1:(N - M + 1)].flip(-1)], dim=-1)
    # orthogonalize
    scale = torch.cat([
        x.new_tensor([math.sqrt(N)]),
        x.new_full((N - 1, ), math.sqrt(0.5 * N))
    ])
    return X / scale
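
A check sketch matching the docstring's scipy reference (assumes rfft is torch.fft.rfft and as_complex packs a trailing (real, imag) pair into a complex tensor):

import torch
from scipy.fft import dct as scipy_dct

x = torch.randn(8, dtype=torch.float64)
ref = torch.from_numpy(scipy_dct(x.numpy(), norm="ortho"))
assert torch.allclose(dct(x), ref)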
Example #15
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    if (not input.is_cuda) and (not torch.backends.mkl.is_available()):
        raise NotImplementedError(
            "For CPU tensor, this method is only supported "
            "with MKL installed.")

    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    M = next_fast_len(N)
    M2 = 2 * M

    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)

    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)

    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / torch.tensor(
        range(N, 0, -1), dtype=input.dtype, device=input.device)
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
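
Usage sketch: the final division pins the lag-0 value at exactly 1 (assumes rfft/irfft from torch.fft and next_fast_len from scipy.fft):

import torch

samples = torch.randn(1000)
acf = autocorrelation(samples)   # dim=0
assert acf.shape == samples.shape
assert torch.isclose(acf[0], torch.tensor(1.0))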
Example #16
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied from `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    M = next_fast_len(N)
    M2 = 2 * M

    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)

    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)

    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / torch.tensor(
        range(N, 0, -1), dtype=input.dtype, device=input.device)
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
Example #17
def train(model,
          train_loader,
          test_loader,
          mode='EDSR_Baseline',
          save_image_every=50,
          save_model_every=10,
          test_model_every=1,
          num_epochs=1000,
          device=None,
          refresh=True):

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    csv = f'./hist_{today}_{mode}.csv'
    if refresh:
        try:
            shutil.rmtree(result_dir)
            shutil.rmtree(weight_dir)
            shutil.rmtree(logger_dir)
        except FileNotFoundError:
            pass
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(logger_dir, exist_ok=True)
    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)

    params = list(model.parameters())
    optim = torch.optim.Adam(params, lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                step_size=1000,
                                                gamma=0.99)
    criterion = torch.nn.L1Loss()

    start_time = time.time()
    print(f'Training Start || Mode: {mode}')

    step = 0
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    hist = dict()
    hist['mode'] = f'{today}_{mode}'
    for key in ['epoch', 'psnr', 'ssim', 'ms-ssim']:
        hist[key] = []

    for epoch in range(num_epochs):

        if epoch == 0:
            torch.save(model.state_dict(),
                       f'{weight_dir}/epoch_{epoch+1:04d}.pth')

        if epoch == 0:
            with torch.no_grad():
                with tqdm(
                        test_loader,
                        desc=
                        f'Mode: {mode} || Warming Up || Test Epoch {epoch}/{num_epochs}',
                        position=0,
                        leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for lr, hr, fname in pbar_test:
                        lr = lr.to(device)
                        hr = hr.to(device)

                        sr, features = model(lr)
                        sr = quantize(sr)

                        psnr, ssim, msssim = evaluate(hr, sr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        psnr_mean = np.array(psnrs).mean()
                        ssim_mean = np.array(ssims).mean()
                        msssim_mean = np.array(msssims).mean()

                        pfix_test['psnr'] = f'{psnr:.4f}'
                        pfix_test['ssim'] = f'{ssim:.4f}'
                        pfix_test['msssim'] = f'{msssim:.4f}'
                        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
                        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
                        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'

                        pbar_test.set_postfix(pfix_test)
                        if len(psnrs) > 1: break

        with tqdm(train_loader,
                  desc=f'Mode: {mode} || Epoch {epoch+1}/{num_epochs}',
                  position=0,
                  leave=True) as pbar:
            psnrs = []
            ssims = []
            msssims = []
            losses = []
            for lr, hr, _ in pbar:
                lr = lr.to(device)
                hr = hr.to(device)

                # prediction
                sr, features = model(lr)

                # frequency-domain loss: L1 between the real and imaginary
                # parts of the row-wise rfft of the two images
                srfft = fft.rfft(sr)
                hrfft = fft.rfft(hr)

                loss_fft = criterion(srfft.real, hrfft.real)
                loss_fft += criterion(srfft.imag, hrfft.imag)

                # training
                loss = criterion(hr, sr)
                loss_tot = loss + 0.1 * loss_fft
                optim.zero_grad()
                loss_tot.backward()
                optim.step()
                scheduler.step()

                # training history
                elapsed_time = time.time() - start_time
                elapsed = sec2time(elapsed_time)
                pfix['Step'] = f'{step+1}'
                pfix['Loss'] = f'{loss.item():.4f}'
                pfix['Loss FFT'] = f'{loss_fft.item():.4f}'

                sr = quantize(sr)
                psnr, ssim, msssim = evaluate(hr, sr)

                psnrs.append(psnr)
                ssims.append(ssim)
                msssims.append(msssim)

                psnr_mean = np.array(psnrs).mean()
                ssim_mean = np.array(ssims).mean()
                msssim_mean = np.array(msssims).mean()

                pfix['PSNR'] = f'{psnr:.2f}'
                pfix['SSIM'] = f'{ssim:.4f}'
                # pfix['MSSSIM'] = f'{msssim:.4f}'
                pfix['PSNR_mean'] = f'{psnr_mean:.2f}'
                pfix['SSIM_mean'] = f'{ssim_mean:.4f}'
                # pfix['MSSSIM_mean'] = f'{msssim_mean:.4f}'

                free_gpu = get_gpu_memory()[0]

                pfix['free GPU'] = f'{free_gpu}MiB'
                pfix['Elapsed'] = f'{elapsed}'

                pbar.set_postfix(pfix)
                losses.append(loss.item())

                if step % save_image_every == 0:

                    z = torch.zeros_like(lr[0])
                    _, _, llr, _ = lr.shape
                    _, _, hlr, _ = hr.shape
                    if hlr // 2 == llr:
                        xz = torch.cat((lr[0], z), dim=-2)
                    elif hlr // 4 == llr:
                        xz = torch.cat((lr[0], z, z, z), dim=-2)
                    imsave([xz, sr[0], hr[0]],
                           f'{result_dir}/epoch_{epoch+1}_iter_{step:05d}.jpg')

                step += 1

            logger.add_scalar("Loss/train", np.array(losses).mean(), epoch + 1)
            logger.add_scalar("PSNR/train", psnr_mean, epoch + 1)
            logger.add_scalar("SSIM/train", ssim_mean, epoch + 1)

            if (epoch + 1) % save_model_every == 0:
                torch.save(model.state_dict(),
                           f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            if (epoch + 1) % test_model_every == 0:

                with torch.no_grad():
                    with tqdm(
                            test_loader,
                            desc=
                            f'Mode: {mode} || Test Epoch {epoch+1}/{num_epochs}',
                            position=0,
                            leave=True) as pbar_test:
                        psnrs = []
                        ssims = []
                        msssims = []
                        for lr, hr, fname in pbar_test:

                            fname = fname[0].split('/')[-1].split('.pt')[0]

                            lr = lr.to(device)
                            hr = hr.to(device)

                            sr, features = model(lr)
                            sr = quantize(sr)

                            psnr, ssim, msssim = evaluate(hr, sr)

                            psnrs.append(psnr)
                            ssims.append(ssim)
                            msssims.append(msssim)

                            psnr_mean = np.array(psnrs).mean()
                            ssim_mean = np.array(ssims).mean()
                            msssim_mean = np.array(msssims).mean()

                            pfix_test['psnr'] = f'{psnr:.4f}'
                            pfix_test['ssim'] = f'{ssim:.4f}'
                            pfix_test['msssim'] = f'{msssim:.4f}'
                            pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
                            pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
                            pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'

                            pbar_test.set_postfix(pfix_test)

                            z = torch.zeros_like(lr[0])
                            _, _, llr, _ = lr.shape
                            _, _, hlr, _ = hr.shape
                            if hlr // 2 == llr:
                                xz = torch.cat((lr[0], z), dim=-2)
                            elif hlr // 4 == llr:
                                xz = torch.cat((lr[0], z, z, z), dim=-2)
                            imsave([xz, sr[0], hr[0]],
                                   f'{result_dir}/{fname}.jpg')

                        hist['epoch'].append(epoch + 1)
                        hist['psnr'].append(psnr_mean)
                        hist['ssim'].append(ssim_mean)
                        hist['ms-ssim'].append(msssim_mean)

                        logger.add_scalar("PSNR/test", psnr_mean, epoch + 1)
                        logger.add_scalar("SSIM/test", ssim_mean, epoch + 1)
                        logger.add_scalar("MS-SSIM/test", msssim_mean,
                                          epoch + 1)

                        df = pd.DataFrame(hist)
                        df.to_csv(csv)
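
The spectral term from the loop above, isolated as a sketch with stand-in tensors (fft is torch.fft; the criterion is the same torch.nn.L1Loss):

import torch
import torch.fft as fft

criterion = torch.nn.L1Loss()
sr, hr = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
srfft, hrfft = fft.rfft(sr), fft.rfft(hr)       # rfft along the last (width) dim
loss_fft = criterion(srfft.real, hrfft.real) + criterion(srfft.imag, hrfft.imag)
loss_tot = criterion(hr, sr) + 0.1 * loss_fft   # same weighting as the loop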
Example #18
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int], dilation: Tuple[int],
                groups: int) -> Tensor:

    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride,
                              padding, dilation)
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        weight_s.append(s_size // d)

    X = rfftn(padded_input, s=s)

    W = rfft(weight, n=weight_s[-1])
    # handle dilation for the last dim
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        W.imag.mul_(-1)
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    Y = _complex_matmul(X, W, groups)

    # handle stride
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                Y = Y.reshape(*Y.shape[:i + 2], st, -1,
                              *Y.shape[i + 3:]).mean(i + 2)

            Y = ifft(Y, dim=i + 2)
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:],
                Y.stride())

    if stride[-1] > 1:
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2,
                                                 step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(
            -2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(
            -1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag),
                                              -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
Example #19
def _fft_conv_transposend(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    stride: Tuple[int],
    padding: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    dilation: Tuple[int],
) -> Tensor:

    output_size = _conv_transpose_shape(input.shape[2:], weight.shape[2:],
                                        stride, padding, output_padding,
                                        dilation)
    padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_output_size, weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size // st)
        weight_s.append(s_size // d)

    X = rfft(input, n=s[-1])
    W = rfft(weight, n=weight_s[-1])

    if stride[-1] > 1:
        X_neg_freq = X.flip(-1)[..., 1:]
        X_neg_freq.imag.mul_(-1)

        tmp = [X]
        for i in range(1, stride[-1]):
            if i % 2:
                tmp.append(X_neg_freq)
            else:
                tmp.append(X[..., 1:])

        X = torch.cat(tmp, -1)

    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(s) > 1:
        X = fftn(X, s=s[:-1], dim=tuple(range(2, X.ndim - 1)))
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + stride[:-1] + (1, )
        if sum(repeats) > X.ndim:
            X = X.repeat(*repeats)

        repeats = (1, 1) + dilation[:-1] + (1, )
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)

    Y = _complex_matmul(X, W, groups, True)

    output = irfftn(Y, dim=tuple(range(2, Y.ndim)))

    # Remove extra padded values
    index = (slice(None), ) * 2 + tuple(
        slice(p, o + p) for p, o in zip(padding, output_size))
    output = output[index].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
Example #20
def _new_rfft(x: torch.Tensor):
    z = new_fft.rfft(x, dim=-1)
    return torch.view_as_real(z)
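
Shape sketch: the shim reproduces the pre-1.8 torch.rfft layout, a real tensor whose trailing dim of size 2 holds (real, imag) pairs:

import torch
import torch.fft as new_fft   # the alias the shim assumes

x = torch.randn(8)
z = _new_rfft(x)
assert z.shape == (5, 2)      # 8 // 2 + 1 bins, each a (real, imag) pair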
Example #21
def convolve1d(
    waveform,
    kernel,
    padding=0,
    pad_type="constant",
    stride=1,
    groups=1,
    use_fft=False,
    rotation_index=0,
):
    """Use torch.nn.functional to perform 1d padding and conv.

    Arguments
    ---------
    waveform : tensor
        The tensor to perform operations on.
    kernel : tensor
        The filter to apply during convolution
    padding : int or tuple
        The padding (pad_left, pad_right) to apply.
        If an integer is passed instead, this is passed
        to the conv1d function and pad_type is ignored.
    pad_type : str
        The type of padding to use. Passed directly to
        `torch.nn.functional.pad`, see PyTorch documentation
        for available options.
    stride : int
        The number of units to move each time convolution is applied.
        Passed to conv1d. Has no effect if `use_fft` is True.
    groups : int
        This option is passed to `conv1d` to split the input into groups for
        convolution. Input channels should be divisible by number of groups.
    use_fft : bool
        When `use_fft` is passed `True`, then compute the convolution in the
        spectral domain using complex multiply. This is more efficient on CPU
        when the size of the kernel is large (e.g. reverberation). WARNING:
        Without padding, circular convolution occurs. This makes little
        difference in the case of reverberation, but may make more difference
        with different kernels.
    rotation_index : int
        This option only applies if `use_fft` is true. If so, the kernel is
        rolled by this amount before convolution to shift the output location.

    Returns
    -------
    The convolved waveform.

    Example
    -------
    >>> from speechbrain.dataio.dataio import read_audio
    >>> signal = read_audio('samples/audio_samples/example1.wav')
    >>> signal = signal.unsqueeze(0).unsqueeze(2)
    >>> kernel = torch.rand(1, 10, 1)
    >>> signal = convolve1d(signal, kernel, padding=(9, 0))
    """
    if len(waveform.shape) != 3:
        raise ValueError("Convolve1D expects a 3-dimensional tensor")

    # Move time dimension last, which pad and fft and conv expect.
    waveform = waveform.transpose(2, 1)
    kernel = kernel.transpose(2, 1)

    # Padding can be a tuple (left_pad, right_pad) or an int
    if isinstance(padding, tuple):
        waveform = torch.nn.functional.pad(
            input=waveform,
            pad=padding,
            mode=pad_type,
        )

    # This approach uses FFT, which is more efficient if the kernel is large
    if use_fft:

        # Pad kernel to same length as signal, ensuring correct alignment
        zero_length = waveform.size(-1) - kernel.size(-1)

        # Handle case where signal is shorter
        if zero_length < 0:
            kernel = kernel[..., :zero_length]
            zero_length = 0

        # Perform rotation to ensure alignment
        zeros = torch.zeros(kernel.size(0),
                            kernel.size(1),
                            zero_length,
                            device=kernel.device)
        after_index = kernel[..., rotation_index:]
        before_index = kernel[..., :rotation_index]
        kernel = torch.cat((after_index, zeros, before_index), dim=-1)

        # Multiply in frequency domain to convolve in time domain
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            import torch.fft as fft

            result = fft.rfft(waveform) * fft.rfft(kernel)
            convolved = fft.irfft(result, n=waveform.size(-1))
        else:
            f_signal = torch.rfft(waveform, 1)
            f_kernel = torch.rfft(kernel, 1)
            sig_real, sig_imag = f_signal.unbind(-1)
            ker_real, ker_imag = f_kernel.unbind(-1)
            f_result = torch.stack(
                [
                    sig_real * ker_real - sig_imag * ker_imag,
                    sig_real * ker_imag + sig_imag * ker_real,
                ],
                dim=-1,
            )
            convolved = torch.irfft(f_result,
                                    1,
                                    signal_sizes=[waveform.size(-1)])

    # Use the implementation given by torch, which should be efficient on GPU
    else:
        convolved = torch.nn.functional.conv1d(
            input=waveform,
            weight=kernel,
            stride=stride,
            groups=groups,
            padding=padding if not isinstance(padding, tuple) else 0,
        )

    # Return time dimension to the second dimension.
    return convolved.transpose(2, 1)
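
A shape-level sketch of the circular FFT path the docstring warns about:

import torch

wav = torch.randn(1, 1000, 1)            # (batch, time, channels)
ker = torch.rand(1, 10, 1)
out = convolve1d(wav, ker, use_fft=True)
assert out.shape == wav.shape            # length-preserving, with wrap-around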
Example #22
    def run(self):
        """
        """
        # TODO: Set up a way to consolidate and produce a final report of all experiments.
        for corpus in self.corpus_type:

            for embedding in self.embedding_type:
                benchmark = GetBenchmarkCorpra(corpus_type=corpus,
                                               parent=self.parent)
                benchmark.run()
                corpra_object = benchmark.corpra

                pre_trained_embedding = GetPreTrainedEmbeddingsStage(
                    embedding_type=embedding, parent=self.parent)
                pre_trained_embedding.run()

                data = Benchmark2Embeddings(embedding_type=embedding,
                                            corpra_object=corpra_object,
                                            min_freq=self.min_freq,
                                            parent=self.parent,
                                            corpus_type=corpus)
                data.run()

                train_file = data.corpra_numeric['train']
                train_labels = data.corpra_labels['train']
                valid_file = None
                test_file = data.corpra_numeric['test']
                test_labels = data.corpra_labels['test']
                vectors = data.vocab.vectors
                dictionary = data.vocab.stoi

                for model_type in self.model_type:

                    self.parent = f"{corpus}_{embedding}_{model_type}"
                    self.logger.info("=" * 40)
                    self.logger.info(
                        f'Running experiments for Corpus: {corpus}'
                        f', with Embedding: {embedding}'
                        f', and Model Type: {model_type}.')
                    self.logger.info("=" * 40)
                    self.model_config.update({'model_type': model_type})

                    model = Model(dictionary_size=len(dictionary),
                                  embedding_vectors=vectors,
                                  embedding_size=vectors.size()[1],
                                  model_task=self.model_task,
                                  input_size=1,
                                  **self.model_config)

                    # run one cycle to check backprop of ft
                    model.train()

                    model.zero_grad()
                    X = torch.tensor(train_file[0])
                    states = generate_initial_states(model)
                    X = X.reshape(X.size(0), 1)

                    output, states = model(X, states)
                    print('output', output)
                    print('states', states)
                    breakpoint()

                token_sample_nums = data.corpra_numeric['train'][0][1]
                token_sample_txts = [
                    data.vocab.itos[i] for i in token_sample_nums
                ][0]
                token_sample_vectors = [
                    data.vocab.vectors[i] for i in token_sample_nums
                ]

                print(token_sample_nums[:20])
                print(token_sample_txts[:20])
                print(len(token_sample_nums))
                print(len(token_sample_vectors))
                print(token_sample_vectors[0])

                fig, ax = plt.subplots(2, 2)

                # This could be in the model class where we define embedding
                self.embedding = prep_embedding_layer(
                    vectors=data.vocab.vectors, trainable=False)
                X = self.embedding(token_sample_nums)

                ft1 = ft.rfftn(input=X, norm="forward")
                ax[0, 0].plot(ft1[:, 1])
                ax[0, 0].set_title(
                    f'rfftn forward - regular 1-dim\n- ft shape({ft1.shape})',
                    size=8)
                # similar to ft1: it's nothing but an FT on a list of numbers
                ft2 = ft.rfft(input=X, norm="forward")
                ax[0, 1].plot(ft2)
                ax[0, 1].set_title(
                    f'rfft forward - regular all-dim\n- ft shape({ft2.shape})',
                    size=8)

                # testing forward: this could be after embedding layer?
                ft3 = ft.rfftn(input=X, norm="forward")
                ax[1, 0].specgram(ft3[:, 1])
                ax[1, 0].set_title(
                    f'rfftn forward - specgram 1-dim\n- ft shape({ft3.shape})',
                    size=8)

                # the all-dim transform viewed as a spectrogram
                ft4 = ft.rfft(input=X, norm="forward")
                ax[1, 1].specgram(ft4.T)
                # ax[1, 1].specgram(ft4[:, 1], alpha=.5)
                ax[1, 1].set_title(
                    f'rfft forward - specgram all-dim\n- ft shape({ft4.shape})',
                    size=8)
                # display figures
                fig.suptitle(
                    f'Plots of Fourier Transforms on a Single IMDB Document\n'
                    f'With Glove 50d Embedding and Doc Len: {len(token_sample_nums)}.'
                )
                fig.tight_layout()
                plt.show()

        return True
Example #23
import torch

from torch.fft import rfft, irfft
from math import pi

fs = 100
# a 1 Hz phase ramp sampled at 100 Hz for 12 s; x1 is 10 Hz, x2 is 20 Hz
t = torch.linspace(0, 12 * 2 * pi, 1200)
x1 = torch.sin(10 * t)
x2 = torch.sin(20 * t)
sp = torch.linspace(0, 50, 601)

bp1 = bandpass(5, 15, 100, N=601)
bp2 = bandpass(18, 22, 100, N=601)

Fx1 = rfft(x1)
Fx2 = rfft(x2)
Fx1 /= Fx1.abs().max()
Fx2 /= Fx2.abs().max()


class TestBandpass(test.TestCase):
    def test_bandpass(self):
        result = bp1(x1 + x2)[100:-100]
        expect = x1[100:-100]
        self.assertClose(expect, result, tol=1e-1)
        result = bp2(x1 + x2)[100:-100]
        expect = x2[100:-100]
        self.assertClose(expect, result, tol=1e-1)

    def test_bandpass_resampled(self):