Example #1
0
    def __init__(self, fft_size, hop_size=None, window_fn='hann', normalize=False):
        """Precompute the fixed (non-trainable) Fourier bases for this module.

        Args:
            fft_size: FFT length; also the length passed to ``build_window``.
            hop_size: hop between frames. ``None`` falls back to
                ``fft_size // 2``.
            window_fn: window name forwarded to ``build_window``.
            normalize: forwarded to ``build_Fourier_bases``. When ``False``
                the bases are additionally divided by ``fft_size`` here.

        The resulting tensor is stored as ``self.bases`` with
        ``requires_grad=False``.
        """
        super().__init__()

        hop_size = fft_size // 2 if hop_size is None else hop_size
        self.fft_size = fft_size
        self.hop_size = hop_size

        # Window re-weighted for the chosen hop by build_optimal_window.
        synthesis_window = build_optimal_window(
            build_window(fft_size, window_fn=window_fn),  # (fft_size,)
            hop_size=hop_size,
        )

        # Only the first fft_size // 2 + 1 rows of each basis are kept.
        # NOTE(review): assumes build_Fourier_bases returns row-indexed
        # (frequency, time) tensors — TODO confirm.
        n_bins = fft_size // 2 + 1
        full_cos, full_sin = build_Fourier_bases(fft_size, normalize=normalize)
        cos_part = full_cos[:n_bins] * synthesis_window
        sin_part = -full_sin[:n_bins] * synthesis_window

        if not normalize:
            # Un-normalised bases get the 1 / fft_size factor applied here.
            cos_part, sin_part = cos_part / fft_size, sin_part / fft_size

        # Cosine rows first, then sine rows; add a singleton dim at axis 1.
        stacked = torch.cat([cos_part, sin_part], dim=0)
        self.bases = nn.Parameter(stacked.unsqueeze(dim=1), requires_grad=False)
Example #2
0
    def __init__(self, in_channels, out_channels, kernel_size, stride=None, window_fn='hann', trainable=False):
        """Build a windowed Fourier basis replicated ``repeat`` times.

        Args:
            in_channels: must be a positive multiple of ``kernel_size * 2``.
                ``repeat = in_channels // (kernel_size * 2)`` copies of the
                [cos, sin] basis pair are stacked.
            out_channels: must be 1.
            kernel_size: length passed to ``build_window`` and
                ``build_Fourier_basis``.
            stride: hop passed to ``build_optimal_window`` and stored as
                ``self.stride``. ``None`` falls back to ``kernel_size // 2``.
            window_fn: window name forwarded to ``build_window``.
            trainable: ``requires_grad`` flag of the stored ``self.basis``.

        Raises:
            ValueError: if ``out_channels != 1``, ``in_channels <= 0``, or
                ``in_channels`` is not divisible by ``kernel_size * 2``.
        """
        super().__init__()

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if out_channels != 1:
            raise ValueError("out_channels is expected 1, given {}".format(out_channels))
        if in_channels <= 0:
            # Otherwise repeat == 0: the basis would be divided by zero and no
            # rows would be produced.
            raise ValueError("in_channels is expected positive, given {}".format(in_channels))
        if in_channels % (kernel_size*2) != 0:
            raise ValueError("in_channels % (kernel_size*2) is given {}".format(in_channels % (kernel_size*2)))

        if stride is None:
            # Resolve the default hop up front so self.stride and
            # build_optimal_window never see None.
            stride = kernel_size // 2

        self.kernel_size, self.stride = kernel_size, stride

        repeat = in_channels // (kernel_size*2)
        self.repeat = repeat

        window = build_window(kernel_size, window_fn=window_fn)  # (kernel_size,)
        optimal_window = build_optimal_window(window, hop_size=stride)

        cos_basis, sin_basis = build_Fourier_basis(kernel_size, normalize=True)
        # Each of the `repeat` copies carries 1 / repeat of the weight.
        cos_basis = cos_basis * optimal_window / repeat
        sin_basis = - sin_basis * optimal_window / repeat

        # Copy idx is circularly shifted by (kernel_size // repeat) * idx along
        # dim 1. Order is [cos_0, sin_0, cos_1, sin_1, ...]. Gather the pieces
        # and concatenate once, rather than re-copying the growing tensor on
        # every iteration.
        shift = kernel_size // repeat
        pieces = [
            torch.roll(part, shift*idx, dims=1)
            for idx in range(repeat)
            for part in (cos_basis, sin_basis)
        ]
        basis = torch.cat(pieces, dim=0)

        # NOTE(review): assuming build_Fourier_basis returns
        # (kernel_size, kernel_size), this is (in_channels, 1, kernel_size)
        # — TODO confirm.
        self.basis = nn.Parameter(basis.unsqueeze(dim=1), requires_grad=trainable)