Example #1
    def istft(self, x):
        spec_r = x[:, 0, :, :]
        spec_i = x[:, 1, :, :]

        n_frames = spec_r.shape[1]

        spec_r = torch.cat(
            [spec_r, spec_r.index_select(dim=-1, index=self.idx)], dim=-1)
        spec_i = torch.cat(
            [spec_i, -spec_i.index_select(dim=-1, index=self.idx)], dim=-1)
        spec_r = spec_r.transpose(-1, -2).contiguous()
        spec_i = spec_i.transpose(-1, -2).contiguous()

        kernel = self.kernel_bw()
        kernel_r = kernel[..., 0].transpose(0, -1)
        kernel_i = kernel[..., 1].transpose(0, -1)

        sig = F.conv_transpose1d(spec_r,
                                 kernel_r,
                                 stride=self.hop_size,
                                 padding=self.win_size-self.hop_size) \
            - F.conv_transpose1d(spec_i,
                                 kernel_i,
                                 stride=self.hop_size,
                                 padding=self.win_size-self.hop_size)
        sig = sig.squeeze(dim=1)

        window = self.window(n_frames)
        sig = sig / (window + self.eps)

        return sig
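The division by window(n_frames) at the end removes the overlap-added window envelope introduced by the transposed convolutions. A minimal sketch of building such an envelope with F.conv_transpose1d itself, with toy window and hop sizes assumed (Examples #4, #6 and #8 below compute their normalization term in a similar way):

import torch
import torch.nn.functional as F

# Overlap-added squared-window envelope (toy sizes assumed): every frame contributes
# window**2 at its hop position; conv_transpose1d of all-ones "frames" sums them up.
win_size, hop_size, n_frames = 8, 4, 5
window = torch.hann_window(win_size)
ones = torch.ones(1, 1, n_frames)                    # one dummy value per frame
weight = (window ** 2).view(1, 1, win_size)          # (C_in=1, C_out=1, win_size)
envelope = F.conv_transpose1d(ones, weight, stride=hop_size)
# envelope: (1, 1, (n_frames - 1) * hop_size + win_size); dividing the raw
# overlap-add output by (envelope + eps) removes the windowing modulation.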
Example #2
    def inverse(self, stft_signal):
        """
        Args:
            stft_signal:
                if complex_representation == 'complex':
                        shape: [..., frames, num_fbins]
                if complex_representation == 'stacked':
                        shape: [..., frames, num_fbins, 2]
                if complex_representation == 'concat':
                        shape: [..., frames, 2 * num_fbins]
            num_samples: list or tensor with the number of samples
        Returns:
            [..., T]

        >>> stft_signal = torch.rand((2, 4, 10, 514))
        >>> torch_stft = STFT(512, 20, window_length=40, \
                              complex_representation='concat')
        >>> torch_signal = torch_stft.inverse(stft_signal)
        >>> torch_signal.shape
        torch.Size([2, 4, 180])
        >>> from paderbox.transform import istft
        >>> signal_np = stft_signal.numpy()
        >>> complex_signal = signal_np[..., :257] + 1j* signal_np[..., 257:]
        >>> time_signal = istft(complex_signal, 512, 20, window_length=40)
        >>> np.testing.assert_allclose(torch_signal, time_signal, atol=1e-5)
        """

        org_shape = stft_signal.shape
        x = stft_signal.view(-1, *org_shape[-2:])

        x = rearrange(x, '... frames feat -> ... feat frames')
        if self.complex_representation == 'stacked':
            signal_real, signal_imag = torch.chunk(stft_signal, 2, dim=-1)
        elif self.complex_representation == 'concat':
            signal_real, signal_imag = torch.chunk(x, 2, dim=-2)
        else:
            raise ValueError(
                f'Please choose one of the predefined output_types '
                f'{self.possible_out_types}, not {self.complex_representation}')
        signal_real = torch.cat([signal_real, signal_real[:, 1:-1].flip(1)],
                                dim=1)
        signal_imag = torch.cat([signal_imag, -signal_imag[:, 1:-1].flip(1)],
                                dim=1)
        kernel_real = self.istft_kernel_real.to(signal_real)
        decoded_real = F.conv_transpose1d(signal_real,
                                          weight=kernel_real,
                                          stride=self.shift)
        kernel_imag = self.istft_kernel_imag.to(signal_imag)
        decoded_imag = F.conv_transpose1d(signal_imag,
                                          kernel_imag,
                                          stride=self.shift)
        time_signal = decoded_real + decoded_imag
        time_signal = time_signal.view(*org_shape[:-2], time_signal.shape[-1])
        if self.fading not in [None, False]:
            pad_width = (self.window_length - self.shift)
            if self.fading == 'half':
                pad_width /= 2
            cut_off = time_signal.shape[-1] - ceil(pad_width)
            time_signal = time_signal[..., int(pad_width):cut_off]
        return time_signal
Example #3
    def backprop_conv1d_input(self, activation, module, R):
        stride, padding, kernel = module.stride, module.padding, module.kernel_size
        output_padding = \
            activation.size(2) - ((R.size(2) - 1) * stride[0] - 2 * padding[0] + kernel[0])

        W_L = torch.clamp(module.weight, min=0)
        W_H = torch.clamp(module.weight, max=0)

        L = torch.ones_like(activation, dtype=activation.dtype) * self.lowest
        H = torch.ones_like(activation, dtype=activation.dtype) * self.highest

        Z_O = F.conv1d(activation, module.weight, stride=stride, padding=padding)
        Z_L = F.conv1d(L, W_L, stride=stride, padding=padding)
        Z_H = F.conv1d(H, W_H, stride=stride, padding=padding)

        Z = Z_O - Z_L - Z_H + 1e-9
        pz = nn.ZeroPad2d((0, 0, 0, 336-40))
        Z = pz(Z[:, :, 2:-2])
        S = R / Z
        S = S[:, :40]

        C_O = F.conv_transpose1d(S, module.weight, stride=stride, padding=2, output_padding=0)
        C_L = F.conv_transpose1d(S, W_L, stride=stride, padding=2, output_padding=0)
        C_H = F.conv_transpose1d(S, W_H, stride=stride, padding=2, output_padding=0)

        R = activation * C_O - L * C_L - H * C_H

        return R
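The function above follows the usual layer-wise relevance propagation (LRP) pattern: the forward pass provides a denominator Z, the relevance ratio S = R / Z is propagated backwards, and F.conv_transpose1d(S, weight) plays the role of the gradient of the conv1d output with respect to its input. A stripped-down epsilon-rule sketch without the bound terms (W_L, W_H) and the hard-coded shape handling; the names and shapes below are illustrative assumptions, not the code above:

import torch
import torch.nn.functional as F

# Epsilon-rule LRP step through a 1-D convolution (minimal sketch, toy shapes assumed):
# relevance is redistributed to the inputs in proportion to their contributions.
def lrp_eps_conv1d(activation, weight, relevance, stride=1, padding=0, eps=1e-9):
    z = F.conv1d(activation, weight, stride=stride, padding=padding) + eps
    s = relevance / z                     # per-output-unit normalized relevance
    # conv_transpose1d with the same weight/stride/padding acts as the adjoint of
    # conv1d here (sizes chosen so that no output_padding is needed)
    c = F.conv_transpose1d(s, weight, stride=stride, padding=padding)
    return activation * c                 # element-wise redistribution onto the input

a = torch.randn(1, 2, 16).abs()           # toy positive activations
w = torch.randn(4, 2, 3)
r = F.conv1d(a, w)                        # pretend the layer output is the relevance
r_in = lrp_eps_conv1d(a, w, r)            # relevance mapped back to the input shape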
Example #4
    def inverse(self,
                magnitude: torch.tensor,
                phase: torch.tensor,
                eps: float = 1e-9) -> torch.tensor:
        conc = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
            dim=1)
        inverse_transform = F.conv_transpose1d(conc,
                                               self.inverse_basis,
                                               stride=self.hop_length,
                                               padding=0)

        # remove window effect
        n_frames = conc.size(-1)
        inverse_size = inverse_transform.size(-1)

        window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)

        weight = self.square_window[:self.filter_length].unsqueeze(
            0).unsqueeze(0)
        window_filter = F.conv_transpose1d(window_filter,
                                           weight,
                                           stride=self.hop_length,
                                           padding=0)
        indices = torch.arange(inverse_size)
        window_filter = window_filter.squeeze() + eps
        window_filter = window_filter[indices]

        inverse_transform /= window_filter

        # scale by hop ratio
        inverse_transform *= self.filter_length / self.hop_length

        return inverse_transform[...,
                                 self.pad_amount:-self.pad_amount].squeeze(1)
Example #5
    def forward(self, x, w0, w1, b1, y):
        x = F.conv_transpose1d(x,
                               w0,
                               None,
                               stride=2,
                               padding=1,
                               output_padding=1)
        x = F.conv_transpose1d(x,
                               w1,
                               b1,
                               stride=1,
                               padding=2,
                               dilation=2,
                               groups=2)

        y = F.conv_transpose1d(y,
                               self.w2,
                               self.b2,
                               stride=2,
                               padding=1,
                               output_padding=1)
        y = F.conv_transpose1d(y,
                               self.w3,
                               None,
                               stride=1,
                               padding=2,
                               dilation=2,
                               groups=3)
        return x, y
Example #6
    def inverse(self, input1, input2, input_type='magphase'):
        """Call the inverse STFT (iSTFT), given tensors produced 
        by the `transform` function.

        Args:
            input1 (tensors): Magnitude/Real-part of STFT with shape 
            [num_batch, num_frequencies, num_frames]
            input2 (tensors): Phase/Imag-part of STFT with shape [
            [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of input tensor's.
            Defaults to 'magphase'.

        Returns:
            tensors: Reconstructed audio given magnitude and phase. Of
                shape [num_batch, num_samples]
        """
        assert input_type in ['magphase', 'realimag']
        if input_type == 'realimag':
            real, imag = input1, input2
        else:
            real = input1*torch.cos(input2)
            imag = input1*torch.sin(input2)
        inputs = torch.cat([real, imag], dim=1)
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
        rm_start, rm_end = self.pad_amount, self.pad_amount+self.num_samples
        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]
        coffidx = torch.where(coff > 1e-8)
        outputs[coffidx] = outputs[coffidx]/(coff[coffidx])
        return outputs.squeeze(dim=1)
Example #7
    def forward(self, spec):
        """ Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        filters = self.get_filters()
        if spec.ndim == 2:
            # Input is (freq, conv_time), output is (time)
            return F.conv_transpose1d(spec.unsqueeze(0),
                                      filters,
                                      stride=self.stride).squeeze()
        if spec.ndim == 3:
            # Input is (batch, freq, conv_time), output is (batch, 1, time)
            return F.conv_transpose1d(spec, filters, stride=self.stride)
        elif spec.ndim > 3:
            # Multiply all the left dimensions together and group them in the
            # batch. Make the convolution and restore.
            view_as = (-1, ) + spec.shape[-2:]
            out = F.conv_transpose1d(spec.view(view_as),
                                     filters,
                                     stride=self.stride)
            return out.view(spec.shape[:-2] + (-1, ))
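The docstring's claim that the transposed convolution is equivalent to overlap-add can be checked directly with an identity filterbank: channel k of every frame is written to position frame_index * stride + k and overlapping samples are summed. A minimal numeric sketch with assumed toy sizes:

import torch
import torch.nn.functional as F

# Overlap-add via conv_transpose1d (toy sizes assumed).
win, hop = 4, 2
frame0 = torch.tensor([1., 2., 3., 4.])
frame1 = torch.tensor([10., 20., 30., 40.])
frames = torch.stack([frame0, frame1], dim=-1).unsqueeze(0)   # (1, win, n_frames)
eye = torch.eye(win).unsqueeze(1)                              # (win, 1, win) identity filters
ola = F.conv_transpose1d(frames, eye, stride=hop)
# ola[0, 0] == [1, 2, 3 + 10, 4 + 20, 30, 40]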
Example #8
    def forward(self, inputs, phase, cplx=False):
        """
        inputs : [B, N//2+1, T] (mags, real)
        phase: [B, N//2+1, T] (phase, imag)
        """

        if cplx:
            # N x 2F x T
            cspec = torch.cat([inputs, phase], dim=1)
        else:
            # N x F x T
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            # N x 2F x T
            cspec = torch.cat([real, imag], dim=1)
        # N x 1 x L
        outputs = F.conv_transpose1d(cspec, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        # 1 x N x T
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        # 1 x 1 x L
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        #outputs = torch.where(coff == 0, outputs, outputs/coff)
        # N x 1 x L
        outputs = outputs[..., self.win_len -
                          self.stride:-(self.win_len - self.stride)]
        # N x L
        outputs = outputs.squeeze(1)
        return outputs
Example #9
    def forward(self, x, mu=0):
        num_batches = x.shape[0]

        D_in = x.shape[-1]
        D_enc = F.conv1d(x, self.get_param("H"), stride=self.stride).shape[-1]

        self.lam = self.sigma * torch.sqrt(2 * torch.log(
            torch.zeros(1, device=self.device) + (self.num_conv * D_enc)))

        x_old = torch.zeros(num_batches,
                            self.num_conv,
                            D_enc,
                            device=self.device)
        yk = torch.zeros(num_batches, self.num_conv, D_enc, device=self.device)
        x_new = torch.zeros(num_batches,
                            self.num_conv,
                            D_enc,
                            device=self.device)
        t_old = torch.tensor(1, device=self.device).float()

        for t in range(self.T):
            H_yk_mu = (F.conv_transpose1d(
                yk, self.get_param("H"), stride=self.stride) + mu)
            if self.model_distribution == "gaussian":
                x_tilda = x - H_yk_mu
            elif self.model_distribution == "binomial":
                x_tilda = x - self.sigmoid(H_yk_mu)
            elif self.model_distribution == "poisson":
                x_tilda = x - torch.exp(H_yk_mu)

            x_new = (yk + F.conv1d(
                x_tilda, self.get_param("H"), stride=self.stride) / self.L)
            if self.twosided:
                x_new = self.relu(torch.abs(x_new) -
                                  self.lam / self.L) * torch.sign(x_new)
            else:
                x_new = self.relu(x_new - self.lam / self.L)

            t_new = (1 + torch.sqrt(1 + 4 * t_old * t_old)) / 2
            yk = x_new + (t_old - 1) / t_new * (x_new - x_old)

            x_old = x_new
            t_old = t_new

        z = F.conv_transpose1d(x_new, self.get_param("H"),
                               stride=self.stride) + mu

        return z, x_new, self.lam
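The FISTA iterations above rely on F.conv1d(., H) and F.conv_transpose1d(., H) acting as adjoint operators (analysis and synthesis with the same dictionary). A minimal numeric check of that adjointness, with toy shapes assumed and the signal length chosen so the strided convolution covers it exactly:

import torch
import torch.nn.functional as F

# <conv1d(x, H), z> == <x, conv_transpose1d(z, H)> when (len(x) - kernel) % stride == 0.
stride = 2
x = torch.randn(1, 1, 33)                 # signal
H = torch.randn(4, 1, 5)                  # dictionary: (num_conv, 1, kernel_size)
z = torch.randn(1, 4, 15)                 # codes; 15 == (33 - 5) // 2 + 1
lhs = (F.conv1d(x, H, stride=stride) * z).sum()
rhs = (x * F.conv_transpose1d(z, H, stride=stride)).sum()
assert torch.allclose(lhs, rhs, atol=1e-3)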
Example #10
 def forward(self, x):
     x = self.conv(x)
     out = F.conv_transpose1d(x,
                              self.upscale_weight,
                              stride=self.upscale,
                              groups=self.upscale_weight.size(0))
     return out
Example #11
def upsampling_experiment(iterations=11,
                          activation=lambda x: F.leaky_relu(x, 0.2),
                          kernel_size=2,
                          stride=2,
                          padding=0):

    with torch.no_grad():
        in_channels = 16
        batch_size = 2
        t = torch.FloatTensor(batch_size, in_channels, 8).normal_(0, 1)

        for i in range(iterations):

            kernel = torch.FloatTensor(in_channels, in_channels,
                                       kernel_size).normal_(0, 1)
            t = F.conv_transpose1d(t, kernel, stride=stride, padding=padding)

            # kernel = torch.FloatTensor(
            #     in_channels * stride, in_channels, kernel_size).normal_(0, 1)
            # t = F.conv1d(t, kernel, stride=1, padding=1)
            # t = t\
            #     .permute(0, 2, 1)\
            #     .contiguous()\
            #     .view(batch_size, -1, in_channels)\
            #     .permute(0, 2, 1)\
            #     .contiguous()
            t = activation(t)

            print(t.shape)

    return t.data.cpu().numpy()
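Each iteration above exactly doubles the temporal length, since kernel_size = stride = 2 and padding = 0. The printed shapes follow the standard transposed-convolution length formula, checked here with the same toy sizes (an illustrative sanity check, not part of the experiment):

import torch
import torch.nn.functional as F

# L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
x = torch.randn(2, 16, 8)
w = torch.randn(16, 16, 2)
y = F.conv_transpose1d(x, w, stride=2, padding=0)
assert y.shape[-1] == (8 - 1) * 2 - 2 * 0 + 1 * (2 - 1) + 0 + 1   # == 16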
Example #12
    def forward(self, input, T=None):
        """
        Args:
            input (batch_size, n_bins, n_frames, 2): n_bins = fft_size // 2 + 1,
                n_frames = (T - fft_size) // hop_size + 1. n_frames may differ
                because of padding.
        Returns:
            output (batch_size, T):
        """
        fft_size, hop_size = self.fft_size, self.hop_size

        if T is None:
            padding = 2 * fft_size
        else:
            padding = (
                hop_size - (T - fft_size) % hop_size
            ) % hop_size + 2 * fft_size  # assumes fft_size % hop_size == 0
        padding_left = padding // 2
        padding_right = padding - padding_left

        real, imag = input[..., 0], input[..., 1]
        input = torch.cat([real, imag, real[:, 1:-1], imag[:, 1:-1]], dim=1)
        bases = torch.cat([
            self.bases, self.bases[1:fft_size // 2],
            self.bases[-fft_size // 2:-1]
        ],
                          dim=0)

        output = F.conv_transpose1d(input, bases, stride=self.hop_size)
        output = F.pad(output, (-padding_left, -padding_right))
        output = output.squeeze(dim=1)

        return output
Example #13
    def test_fake_quant_per_channel_other_prec(self):
        kernel_size = 3

        quant_desc_input = QuantDescriptor(num_bits=4)
        quant_desc_weight = QuantDescriptor(num_bits=3, axis=(1))

        quant_conv_object = quant_conv.QuantConvTranspose1d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=False,
            quant_desc_input=quant_desc_input,
            quant_desc_weight=quant_desc_weight)
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 16)

        test_input_quantizer = TensorQuantizer(quant_desc_input)
        weight_quantizer = TensorQuantizer(quant_desc_weight)

        quant_input = test_input_quantizer(test_input)

        weight_copy = quant_conv_object.weight.clone()
        quant_weight = weight_quantizer(weight_copy)

        out1 = F.conv_transpose1d(quant_input, quant_weight)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
Example #14
    def forward(self, x):
        # Pad here if not using transposed conv
        input_size = x.shape[2]
        if self.padding != "valid":
            num_pad = (self.kernel_size - 1) // 2
            out = F.pad(x, (num_pad, num_pad), mode=self.padding)
        else:
            out = x

        # Lowpass filter (+ 0 insertion if transposed)
        if self.transpose:
            expected_steps = ((input_size - 1) * self.stride + 1)
            if self.padding == "valid":
                expected_steps = expected_steps - self.kernel_size + 1

            out = F.conv_transpose1d(out,
                                     self.filter,
                                     stride=self.stride,
                                     padding=0,
                                     groups=self.channels)
            diff_steps = out.shape[2] - expected_steps
            if diff_steps > 0:
                assert (diff_steps % 2 == 0)
                out = out[:, :, diff_steps // 2:-diff_steps // 2]
        else:
            assert (input_size % self.stride == 1)
            out = F.conv1d(out,
                           self.filter,
                           stride=self.stride,
                           padding=0,
                           groups=self.channels)

        return out
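The comment "Lowpass filter (+ 0 insertion if transposed)" summarizes what the transposed branch computes: conv_transpose1d with stride s is the same as inserting s - 1 zeros between samples and then running a full convolution with the kernel. A minimal numeric check of that equivalence, with assumed toy shapes:

import torch
import torch.nn.functional as F

# conv_transpose1d == zero insertion + full convolution (toy shapes assumed).
s, k = 2, 5
x = torch.randn(1, 3, 10)
w = torch.randn(3, 1, k)                           # conv_transpose weight: (C_in, C_out, k)
ref = F.conv_transpose1d(x, w, stride=s)

up = torch.zeros(1, 3, (10 - 1) * s + 1)
up[..., ::s] = x                                   # insert s - 1 zeros between samples
w_conv = w.transpose(0, 1).flip(-1)                # conv1d weight: (C_out, C_in, k), flipped
out = F.conv1d(F.pad(up, (k - 1, k - 1)), w_conv)  # "full" convolution
assert torch.allclose(ref, out, atol=1e-5)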
Example #15
    def __init__(self, hyp, H=None):
        self.num_examples = hyp["num_examples"]
        # self.minibatch = hyp["minibatch"]
        self.x_dim = hyp["x_dim"]
        self.y_dim = hyp["y_dim"]
        self.dictionary_dim = hyp["dictionary_dim"]
        self.num_conv = hyp["num_conv"]
        self.sparsity = hyp["sparsity"]  # key assumed; used by generate_sparse_samples below
        self.device = hyp["device"]
        self.random = hyp["random"]
        self.seed = hyp["seed"]
        x = generate_sparse_samples(self.num_examples,
                                    self.num_conv,
                                    self.x_dim,
                                    self.sparsity,
                                    device=self.device,
                                    random=self.random,
                                    seed=self.seed)
        if H is None:
            H = torch.randn((self.num_conv, 1, self.dictionary_dim),
                            device=self.device)
        # dictionary atoms with unit l2 norm, shape (num_conv, 1, dictionary_dim)
        self.H = F.normalize(H.to(self.device), p=2, dim=-1)

        # synthesize the observations from the sparse codes with the dictionary
        self.y = F.conv_transpose1d(x, self.H)
Example #16
 def forward(self, m, p, cplx=False, squeeze=False):
     """
     Accept phase & magnitude and output raw waveform
     args
         m, p: N x F x T
     return
         s: N x S
     """
     if p.dim() != m.dim() or p.dim() not in [2, 3]:
         raise RuntimeError("Expect 2D/3D tensor, but got {:d}D".format(
             p.dim()))
     # if F x T, reshape 1 x F x T
     if p.dim() == 2:
         p = th.unsqueeze(p, 0)
         m = th.unsqueeze(m, 0)
     if cplx:
         # N x 2F x T
         c = th.cat([m, p], dim=1)
     else:
         r = m * th.cos(p)
         i = m * th.sin(p)
         # N x 2F x T
         c = th.cat([r, i], dim=1)
     # N x 2F x T
     s = F.conv_transpose1d(c, self.K, stride=self.stride, padding=0)
     # N x S
     s = s.squeeze(1)
     if squeeze:
         s = th.squeeze(s)
     return s
Example #17
    def inverse(self,
                magnitude: torch.tensor,
                phase: torch.tensor,
                eps: float = 1e-9) -> torch.tensor:
        conc = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
            dim=1)

        inverse_transform = F.conv_transpose1d(conc,
                                               self.inverse_basis,
                                               stride=self.hop_length,
                                               padding=0)

        # remove window effect
        if self.window is not None:
            n_frames = conc.size(-1)
            inverse_size = inverse_transform.size(-1)
            window_filter = torch.zeros(inverse_size).type_as(
                inverse_transform).fill_(eps)

            for idx in range(n_frames):
                sample = idx * self.hop_length
                window_filter[sample:min(inverse_size, sample + self.filter_length)] \
                    += self.square_window[:max(0, min(self.filter_length, inverse_size - sample))]

            inverse_transform /= window_filter

            # scale by hop ratio
            inverse_transform *= self.filter_length / self.hop_length

        return inverse_transform[...,
                                 self.pad_amount:-self.pad_amount].squeeze(1)
Example #18
 def forward(self, x):
     return F.conv_transpose1d(
         input=x,
         weight=self.weight * self.scale,  # scale the weight on runtime
         bias=self.bias if self.use_bias else None,
         stride=self.stride,
         padding=self.pad)
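Multiplying the stored weight by a constant on every forward call (instead of folding the scale into the weight at initialization) is the pattern commonly used for the equalized learning-rate trick in progressive-growing/StyleGAN-style generators. A minimal sketch with an assumed He-style gain; how self.scale is actually computed depends on the surrounding module:

import math
import torch
import torch.nn.functional as F

# Runtime weight scaling (sketch, assumed gain): weights stored ~ N(0, 1), rescaled per call.
in_ch, out_ch, k = 8, 4, 9
weight = torch.randn(in_ch, out_ch, k)             # conv_transpose weight: (C_in, C_out, k)
scale = math.sqrt(2.0 / (in_ch * k))               # assumed He-style gain
x = torch.randn(1, in_ch, 64)
y = F.conv_transpose1d(x, weight * scale, stride=1, padding=k // 2)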
Example #19
    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = librosa.filters.window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False).to(
                magnitude.device
            )
            inverse_transform[..., approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= self.filter_length / self.hop_length

        inverse_transform = inverse_transform[..., self.pad_amount:]
        inverse_transform = inverse_transform[..., :-self.pad_amount]
        inverse_transform = inverse_transform.squeeze(1)

        return inverse_transform
Example #20
 def forward(self, x):
     if x.dim() == 5:
         x = torch.cat([*x.chunk()], 2).squeeze()
     return Q(
         F.conv_transpose1d(x, self.weight, self.bias, self.stride,
                            self.padding, self.output_padding, self.groups,
                            self.dilation))
Example #21
    def forward(self, signal):
        """ Algorithm 1 from the paper, runs convolutional learned FISTA. The signals is denoted 'y' in the paper."""
        torch.set_default_tensor_type(torch.DoubleTensor)
        x = torch.zeros(self.code_size)
        prev_x = torch.zeros(self.code_size)
        s = 0
        prev_s = 0
        pad = self.kernel_size//2   #TODO: figure padding
        soft_threshold = nn.Softshrink(self.lam / self.L)
        for t in range(self.T):
            s = (1+(1+4*prev_s**2)**0.5)/2      # line 3 in Algorithm 1

            w = x + ((prev_s - 1)/s)*(x - prev_x)     # line 4

            # line 5
            print("H's shape: {}, w's shape: {}".format(self.H.shape, w.shape))
            v = torch.mm(self.H, w)
            print("v's shape: {}, signal's shape: {}, x: {}, local dictionary's shape: {}".format(v.shape, signal.shape, x.shape, self.local_dictionary.shape))
            v = signal - v
            c = w + (1/self.L)*F.conv_transpose1d(v, self.local_dictionary, padding=pad)       # TODO: check translation

            prev_x = x      # line 6
            x = soft_threshold(c)

        return x        # maybe return F.conv1d(z, self.H)?
Example #22
    def forward(self, inputs: 'Tensor') -> 'Tensor':
        """ Forward pass method for transposed convolution layer.

        Parameters
        ----------
        inputs : torch Tensor
            input tensor for transposed convolution layer.

        Returns
        -------
        torch Tensor
            result of the transposed convolution applied to the input tensor.
        """
        conv_args = (inputs, self.weight, self.bias, self.stride, 0, 0,
                     self.groups, self.dilation)

        if self.ndims == 2:
            x = F.conv_transpose1d(*conv_args)
        elif self.ndims == 3:
            x = F.conv_transpose2d(*conv_args)
        elif self.ndims == 4:
            x = F.conv_transpose3d(*conv_args)

        if self.crop:
            x = crop(x, self.crop_sizes)
        return x
Example #23
File: stft.py  Project: volcacius/brevitas
    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]

        return inverse_transform
Example #24
    def forward(self,
                x: torch.Tensor,
                output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose1d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv1d
        """

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size,
            self.dilation)  # type: ignore[arg-type]

        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose1d(x, weight_quant_dequant, self.bias,
                                    self.stride, self.padding, output_padding,
                                    self.groups, self.dilation)
        return result
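As the docstring describes, the weight is quantize-dequantized (via self.get_weight()) while the convolution itself still runs in float, so a backend can later fuse the starred ops into a true quantized conv. A minimal fake-quantization sketch under an assumed symmetric per-tensor int8 scheme (illustrative only, not the module's actual observer/quantizer):

import torch
import torch.nn.functional as F

# Fake-quantize the float weight, then run the ordinary float conv_transpose1d.
def fake_quant(t, num_bits=8):
    qmax = 2 ** (num_bits - 1) - 1
    scale = t.abs().max() / qmax                    # per-tensor symmetric scale (assumed)
    return torch.clamp(torch.round(t / scale), -qmax - 1, qmax) * scale

x = torch.randn(1, 4, 16)
w = torch.randn(4, 8, 3)                            # (in_channels, out_channels, kernel_size)
y = F.conv_transpose1d(x, fake_quant(w), stride=2)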
Example #25
 def conv_transpose1d(self, x, weight, bias, output_padding):
     if self.padding_type == PaddingType.SAME:
         out = self.transposeconv1d_same_padding(x, weight, bias,
                                                 output_padding)
     else:
         out = conv_transpose1d(x, weight, bias, self.stride, self.padding,
                                output_padding, self.groups, self.dilation)
     return out
Example #26
 def forward(self, x):
     si = x.size()
     x = x.reshape(si[0], si[1] * si[2], si[3])
     output = F.conv_transpose1d(x, self.weight, self.bias, self.stride,
                                 self.padding, self.output_padding,
                                 self.num, self.dilation)
     so = output.size()[-1]
     return output.view(si[0], si[1], self.out_channels, so)
Example #27
    def decode(self, X):
        for i, layer in enumerate(reversed(self.encoder_net)):
            # skip the last layer of the encoder, which is a non-linearity.
            # Do not want non-linearity -> non-linearity.
            if not i:
                continue

            if isinstance(layer, nn.Conv1d):
                if layer.bias is not None:
                    X = X + layer.bias.unsqueeze(1)

                st_pad = layer.stride[0] // 2
                X = F.conv_transpose1d(X, layer.weight, None, layer.stride,
                                       layer.padding, st_pad, layer.groups,
                                       layer.dilation)
            elif isinstance(layer, CausalConv1d):
                if layer.conv.bias is not None:
                    X = X + layer.conv.bias.unsqueeze(1)

                # No symmetrical padding in temporal dimension
                padding = 0
                st_pad = layer.conv.stride[0] // 2
                X = F.conv_transpose1d(X, layer.conv.weight, None,
                                       layer.conv.stride, padding, st_pad,
                                       layer.conv.groups, layer.conv.dilation)

                # remove all implicit T padding from the end (therefore causal)
                X = layer.chomp(X)

            elif isinstance(layer, nn.Linear):
                X = F.linear(X, layer.weight.transpose(0, 1), layer.bias)
            elif isinstance(layer, nn.AvgPool1d):
                X = F.interpolate(X, scale_factor=layer.stride)
            elif isinstance(layer, Squeezer):
                X = torch.unsqueeze(X, dim=2)
            elif isinstance(layer, Flatten):
                # dirty hack to determine the correct C dimensionality
                # [i-2] if one dense, [i-4] if two dense
                X = X.view(
                    X.shape[0],
                    list(self.encoder_net.children())[i - 2].weight.shape[0],
                    -1)
            # elif not isinstance(layer, nn.BatchNorm1d):
            else:
                X = layer(X)
        return torch.tanh(X)
Example #28
 def _apply_kernel(signal, kernel, reflect):
     signal = signal.view(-1, *org_shape[-2:])
     signal = rearrange(signal, '... frames feat -> ... feat frames')
     if reflect:
         signal = torch.cat([signal, -signal[:, 1:-1].flip(1)], dim=1)
     else:
         signal = torch.cat([signal, signal[:, 1:-1].flip(1)], dim=1)
     return F.conv_transpose1d(signal, weight=kernel, stride=self.shift)
Example #29
 def forward_pass(m, in_tensor, weight):
     return F.conv_transpose1d(in_tensor,
                               weight,
                               stride=m.stride,
                               padding=m.padding,
                               output_padding=m.output_padding,
                               dilation=m.dilation,
                               groups=m.groups).detach()
Example #30
 def backward_pass(layer, in_tensor, weight):
     return F.conv_transpose1d(in_tensor,
                               weight,
                               stride=layer.stride,
                               padding=layer.padding,
                               output_padding=layer.output_padding,
                               groups=layer.groups,
                               dilation=layer.dilation).detach()
Example #31
 def transposed_convolve(self, x):
     x = F.conv_transpose1d(
         x, self.filter_bank, padding=self.filter_bank.shape[-1] // 2)
     return x
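With the default stride of 1 and padding equal to half the (odd) kernel length, the transposed convolution above leaves the time dimension unchanged. A quick check with assumed toy shapes:

import torch
import torch.nn.functional as F

# (L - 1) - 2 * (k // 2) + (k - 1) + 1 == L for odd k, so the length is preserved.
x = torch.randn(1, 3, 50)
filter_bank = torch.randn(3, 1, 9)                  # (in_channels, out_channels, odd kernel)
y = F.conv_transpose1d(x, filter_bank, padding=filter_bank.shape[-1] // 2)
assert y.shape[-1] == x.shape[-1]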