示例#1
0
def test_stft_istft_identity(ctx, window_size, stride, fft_size, window_type,
                             center, pad_mode):
    """Check that ``F.istft(F.stft(x))`` reproduces ``x`` on the given context.

    The test is skipped on the plain ``cuda`` backend and for parameter
    combinations that ``is_nola_violation`` reports as violating the NOLA
    (Nonzero Overlap Add) condition.

    Args:
        ctx: NNabla context; the part of ``ctx.backend[0]`` before ``":"``
            selects the backend name.
        window_size, stride, fft_size, window_type, center: Forwarded to both
            ``F.stft`` and ``F.istft``.
        pad_mode: Forwarded to ``F.stft`` only; ``F.istft`` is always called
            with ``pad_mode="constant"``.
    """
    backend = ctx.backend[0].split(":")[0]
    if backend == 'cuda':
        pytest.skip(
            'CUDA Convolution N-D is only supported in CUDNN extension')

    x_shape = create_stft_input_shape(window_size)

    # Skip for NOLA condition violation. pytest.skip() raises, so no explicit
    # ``return`` is needed after it, and the random input is only generated
    # for cases that actually run.
    length = x_shape[1]
    if is_nola_violation(window_type, window_size, stride, fft_size, length,
                         center):
        pytest.skip('NOLA condition violation.')

    x = nn.Variable.from_numpy_array(np.random.randn(*x_shape))
    with nn.context_scope(ctx):
        yr, yi = F.stft(x, window_size, stride, fft_size, window_type, center,
                        pad_mode)
        z = F.istft(yr,
                    yi,
                    window_size,
                    stride,
                    fft_size,
                    window_type,
                    center,
                    pad_mode="constant")
    z.forward()

    # The round trip must give back the original signal.
    assert np.allclose(x.d, z.d, atol=1e-5, rtol=1e-5)
示例#2
0
def test_istft(ctx, window_size, stride, fft_size, window_type, center):
    """Check that iSTFT(STFT(x)) = x on the given context.

    The first and last ``window_size - stride`` samples are excluded from the
    comparison. The test is skipped on the plain ``cuda`` backend.

    Args:
        ctx: NNabla context; the part of ``ctx.backend[0]`` before ``":"``
            selects the backend name.
        window_size, stride, fft_size, window_type, center: Forwarded to both
            ``F.stft`` and ``F.istft``.
    """
    backend = ctx.backend[0].split(":")[0]
    if backend == 'cuda':
        pytest.skip('CUDA Convolution N-D is only supported in CUDNN extension')

    # clear all previous STFT conv/deconv kernels
    nn.clear_parameters()

    # Make sure that iSTFT(STFT(x)) = x
    x = np.random.randn(1, window_size * 10)

    nx = nn.Variable.from_numpy_array(x)
    with nn.context_scope(ctx):
        nyr, nyi = F.stft(nx,
                          window_size=window_size,
                          stride=stride,
                          fft_size=fft_size,
                          window_type=window_type,
                          center=center)
        nz = F.istft(nyr, nyi,
                     window_size=window_size,
                     stride=stride,
                     fft_size=fft_size,
                     window_type=window_type,
                     center=center)
    nz.forward()

    invalid = window_size - stride
    # With stride == window_size, ``invalid`` is 0 and ``[:, 0:-0]`` would be
    # an empty slice, making the comparison pass vacuously. Use an open-ended
    # stop in that case so the whole signal is compared.
    stop = -invalid if invalid else None
    valid = slice(invalid, stop)
    assert np.allclose(nx.d[:, valid],
                       nz.d[:, valid],
                       atol=1e-5, rtol=1e-5)
示例#3
0
def stft_backward(inputs,
                  window_size,
                  stride,
                  fft_size,
                  window_type='hanning',
                  center=True,
                  pad_mode='reflect',
                  as_istft_backward=False):
    """
    Express the backward pass of STFT as a call to ``F.istft``.

    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
        Only the first two entries are used: the grads of the real and of the
        imaginary STFT output.
      window_size, stride, fft_size, window_type, center, pad_mode: Arguments of
        the forward function, passed on unchanged to ``F.istft``.
      as_istft_backward (bool): Its negation is passed to ``F.istft`` as
        ``as_stft_backward``.

    Return:
      The output of ``F.istft``, i.e. the gradient wrt the STFT input.
    """
    grad_real, grad_imag = inputs[0], inputs[1]

    return F.istft(grad_real,
                   grad_imag,
                   window_size=window_size,
                   stride=stride,
                   fft_size=fft_size,
                   window_type=window_type,
                   center=center,
                   pad_mode=pad_mode,
                   as_stft_backward=not as_istft_backward)
示例#4
0
def test_istft(window_size, stride, fft_size, window_type, center):
    """Check that iSTFT(STFT(x)) = x.

    The first and last ``window_size - stride`` samples are excluded from the
    comparison.

    Args:
        window_size, stride, fft_size, window_type, center: Forwarded to both
            ``F.stft`` and ``F.istft``.
    """
    # clear all previous STFT conv/deconv kernels
    nn.clear_parameters()

    # Make sure that iSTFT(STFT(x)) = x
    x = np.random.randn(1, window_size * 10)

    nx = nn.Variable.from_numpy_array(x)
    nyr, nyi = F.stft(nx,
                      window_size=window_size,
                      stride=stride,
                      fft_size=fft_size,
                      window_type=window_type,
                      center=center)
    nz = F.istft(nyr,
                 nyi,
                 window_size=window_size,
                 stride=stride,
                 fft_size=fft_size,
                 window_type=window_type,
                 center=center)
    nz.forward()

    invalid = window_size - stride
    # With stride == window_size, ``invalid`` is 0 and ``[:, 0:-0]`` would be
    # an empty slice, making the comparison pass vacuously. Use an open-ended
    # stop in that case so the whole signal is compared.
    stop = -invalid if invalid else None
    valid = slice(invalid, stop)
    assert np.allclose(nx.d[:, valid],
                       nz.d[:, valid],
                       atol=1e-5,
                       rtol=1e-5)
示例#5
0
def check_nola_violation(y_r, y_i, window_size, stride, fft_size, window_type,
                         center, pad_mode, as_stft_backward):
    """Assert that ISTFT rejects inputs violating the NOLA condition.

    Two checks are made:

    1. If PyTorch can be imported, ``torch.istft`` must raise a
       ``RuntimeError`` matching ``"window overlap add"``.
    2. ``F.istft`` must raise a ``RuntimeError`` matching
       ``"NOLA(Nonzero Overlap Add) condition is not met."``.

    Args:
        y_r, y_i: NumPy arrays with the real and imaginary parts of the
            spectrogram (indexed as 3-D arrays by the reference check).
        window_size, stride, fft_size, window_type, center: ISTFT settings
            used by both checks.
        pad_mode, as_stft_backward: Forwarded to ``F.istft`` only.
    """
    # Check reference raise
    # Use PyTorch to check NOLA condition because librosa does not raise.
    # If PyTorch is not installed, NOLA condition test for reference is skipped.
    # Only a failed import is tolerated: a bare ``except`` around the whole
    # check would also swallow the "DID NOT RAISE" failure of pytest.raises
    # and make the reference check unable to fail.
    try:
        import torch
    except ImportError:
        # Install PyTorch to check NOLA condition validation of reference istft.
        torch = None

    if torch is not None:

        def ref_istft_torch(y_r, y_i, window_size, stride, fft_size,
                            window_type, center):
            # Append a trailing axis to each part and join them on it, so the
            # last axis holds (real, imag) as torch.view_as_complex expects.
            y_r = np.reshape(y_r, y_r.shape + (1, ))
            y_i = np.reshape(y_i, y_i.shape + (1, ))
            y = np.concatenate((y_r, y_i), axis=3)

            y = torch.tensor(y)
            y = torch.view_as_complex(y)
            window = torch.tensor(create_window_func(window_type, window_size))

            x_shape = create_stft_input_shape(window_size)
            length = x_shape[1]

            x = torch.istft(y,
                            n_fft=fft_size,
                            hop_length=stride,
                            win_length=window_size,
                            window=window,
                            center=center,
                            length=length)
            return x

        with pytest.raises(RuntimeError, match=r"window overlap add"):
            # NOLA condition is checked during forward execution.
            ref_istft_torch(y_r, y_i, window_size, stride, fft_size,
                            window_type, center)

    # Check NNabla raise
    y_r = nn.Variable.from_numpy_array(y_r)
    y_i = nn.Variable.from_numpy_array(y_i)
    with pytest.raises(
            RuntimeError,
            match=r"NOLA\(Nonzero Overlap Add\) condition is not met."):
        # NOLA condition is checked during setup.
        _ = F.istft(y_r, y_i, window_size, stride, fft_size, window_type,
                    center, pad_mode, as_stft_backward)
示例#6
0
def ref_istft(y_r, y_i, window_size, stride, fft_size, window_type, center,
              pad_mode, as_stft_backward):
    """Compute the reference result for ``F.istft``.

    Args:
        y_r, y_i: NumPy arrays with the real and imaginary parts of the
            spectrogram; the first axis is the batch axis.
        window_size, stride, fft_size, window_type, center, pad_mode: ISTFT
            settings. ``fft_size`` and ``pad_mode`` are only used in the
            ``as_stft_backward`` branch.
        as_stft_backward (bool): Selects the reference. If false,
            ``librosa.istft`` is used. If true, the reference is the gradient
            obtained by running the backward pass of ``F.stft``.

    Returns:
        Forward mode: a NumPy array stacking one ``librosa.istft`` result per
        batch item. Backward mode: ``x.g``, the gradient accumulated on the
        STFT input.
    """
    if not as_stft_backward:
        # Use librosa.istft as the forward reference.

        # Convert to librosa.istft input format.
        y = y_r + 1j * y_i

        # Get original signal length.
        x_shape = create_stft_input_shape(window_size)
        length = x_shape[1]

        # Hand librosa the 'hann' spelling of the Hann window, as ref_stft
        # does for librosa.stft.
        window_type = 'hann' if window_type == 'hanning' else window_type

        # librosa.istft does not support batched input, so run it per item.
        return np.array([
            librosa.istft(spec,
                          hop_length=stride,
                          win_length=window_size,
                          window=window_type,
                          center=center,
                          length=length) for spec in y
        ])
    else:
        # Use F.stft backward as the reference

        y_r = nn.Variable.from_numpy_array(y_r)
        y_i = nn.Variable.from_numpy_array(y_i)

        # Build a variable to serve as the stft input; it is created by istft
        # in stft-backward mode.
        x = F.istft(y_r, y_i, window_size, stride, fft_size, window_type,
                    center, pad_mode, True)

        # Execute stft backward, feeding the spectrogram in as output grads.
        x.need_grad = True
        x.grad.zero()
        z_r, z_i = F.stft(x, window_size, stride, fft_size, window_type,
                          center, pad_mode)

        z_r.g = y_r.d
        z_i.g = y_i.d
        # one_input_grad=False is passed so that the grads assigned above are
        # the ones propagated by backward() -- TODO confirm against the
        # F.sink documentation.
        z = F.sink(z_r, z_i, one_input_grad=False)
        z.forward()
        z.backward()

        return x.g
示例#7
0
文件: test_stft.py 项目: sony/nnabla
def ref_stft(x, window_size, stride, fft_size, window_type, center, pad_mode,
             as_istft_backward):
    """Compute the reference result for ``F.stft``.

    Args:
        x: NumPy array of signals; the first axis is the batch axis.
        window_size, stride, fft_size, window_type, center, pad_mode: STFT
            settings shared by both reference paths.
        as_istft_backward (bool): If false, ``librosa.stft`` is the reference.
            If true, the reference is the gradient obtained by running the
            backward pass of ``F.istft``.

    Returns:
        A ``(real, imag)`` pair: the two parts of the stacked librosa
        spectrograms, or ``(y_r.g, y_i.g)`` in backward mode.
    """
    if as_istft_backward:
        # Use F.istft backward as the reference
        signal = nn.Variable.from_numpy_array(x)

        # Just create istft inputs
        y_r, y_i = F.stft(signal, window_size, stride, fft_size, window_type,
                          center, pad_mode)

        # Execute istft backward
        for part in (y_r, y_i):
            part.need_grad = True
            part.grad.zero()
        z = F.istft(y_r, y_i, window_size, stride, fft_size, window_type,
                    center, pad_mode)

        z.forward()
        z.backward(signal.data)

        return y_r.g, y_i.g

    # Use librosa.stft as the forward reference.
    librosa_window = window_type
    if window_type == 'hanning':
        librosa_window = 'hann'

    # librosa.stft does not support batched input, so transform each item of
    # the batch on its own and stack the results.
    batch_size = x.shape[0]
    spectra = np.array([
        librosa.stft(x[idx],
                     n_fft=fft_size,
                     hop_length=stride,
                     win_length=window_size,
                     window=librosa_window,
                     center=center,
                     pad_mode=pad_mode) for idx in range(batch_size)
    ])

    # Convert to nnabla stft output format
    return spectra.real, spectra.imag
示例#8
0
def istft(y_r,
          y_i,
          window_size,
          stride,
          fft_size,
          window_type='hanning',
          center=True):
    '''Workaround wrapper of ISTFT for fixing a bug in nnabla<=1.15.0

    When ``get_nnabla_version_integer()`` is greater than 11500 the call is
    delegated to ``F.istft``. Otherwise the inverse STFT is computed here with
    two deconvolutions whose kernels are stored as the parameters
    ``conv_cos`` and ``conv_sin``.

    NOTE: once these two parameters exist they are reused as-is. The window
    and FFT arguments of later calls are not checked against them.

    Args:
        y_r: Real part of the STFT.
        y_i: Imaginary part of the STFT.
        window_size (int): Size of the analysis window.
        stride (int): Hop size; ``fft_size`` must be a multiple of it.
        fft_size (int): FFT size; must not be smaller than ``window_size``.
        window_type (str): ``'hanning'``, ``'hamming'`` or ``'rectangular'``
            (``None`` is treated as rectangular).
        center (bool): If true, ``fft_size // 2`` samples are cut from the
            start of the result and ``-(-fft_size // 2)`` from its end.

    Returns:
        The reconstructed signal, reshaped to
        ``(x_cos.shape[0], x_cos.shape[2])`` before optional cropping.

    Raises:
        ValueError: For an unknown window type, ``fft_size < window_size``, or
            ``fft_size`` not being a multiple of ``stride``.
    '''
    from utils import get_nnabla_version_integer
    if get_nnabla_version_integer() > 11500:
        # Forward the arguments explicitly. ``**locals()`` would also contain
        # the function-local import ``get_nnabla_version_integer`` and hand it
        # to ``F.istft`` as an unexpected keyword argument.
        return F.istft(y_r,
                       y_i,
                       window_size=window_size,
                       stride=stride,
                       fft_size=fft_size,
                       window_type=window_type,
                       center=center)
    import numpy as np
    from nnabla.parameter import get_parameter, get_parameter_or_create
    conv_cos = get_parameter('conv_cos')
    conv_sin = get_parameter('conv_sin')

    if conv_cos is None or conv_sin is None:
        # The window is built with one extra point and the last one dropped.
        if window_type == 'hanning':
            window_func = np.hanning(window_size + 1)[:-1]
        elif window_type == 'hamming':
            window_func = np.hamming(window_size + 1)[:-1]
        elif window_type == 'rectangular' or window_type is None:
            window_func = np.ones(window_size)
        else:
            raise ValueError("Unknown window type {}.".format(window_type))

        # pad window if `fft_size > window_size`
        if fft_size > window_size:
            diff = fft_size - window_size
            window_func = np.pad(window_func, (diff // 2, diff - diff // 2),
                                 mode='constant')
        elif fft_size < window_size:
            raise ValueError(
                "FFT size has to be as least as large as window size.")

        # compute inverse STFT filter coefficients
        if fft_size % stride != 0:
            raise ValueError("FFT size needs to be a multiple of stride.")

        # Sum of the squared window over all hop-shifted positions; used to
        # normalise the overlap-add.
        squared_window = np.square(window_func)
        inv_window_func = np.zeros_like(window_func)
        for s in range(0, fft_size, stride):
            inv_window_func += np.roll(squared_window, s)

        # Per-bin weight: the first and last bins (w == 0 and
        # w == fft_size // 2) count once, all others twice.
        n_bins = fft_size // 2 + 1
        alpha = np.full((n_bins, 1, 1), 2.0)
        alpha[0] = 1.0
        alpha[-1] = 1.0
        alpha /= fft_size

        # Broadcast bin and sample indices to the kernel shape
        # (n_bins, 1, fft_size) instead of looping over every pair in Python.
        w = np.arange(n_bins).reshape(-1, 1, 1)
        t = np.arange(fft_size).reshape(1, 1, -1)
        phase = 2. * np.pi * w * t / fft_size

        mat_cos = alpha * np.cos(phase) * window_func / inv_window_func
        mat_sin = alpha * np.sin(phase) * window_func / inv_window_func

        conv_cos = get_parameter_or_create('conv_cos',
                                           initializer=mat_cos,
                                           need_grad=False)
        conv_sin = get_parameter_or_create('conv_sin',
                                           initializer=mat_sin,
                                           need_grad=False)

    # compute inverse STFT
    x_cos = F.deconvolution(y_r, conv_cos, stride=(stride, ))
    x_sin = F.deconvolution(y_i, conv_sin, stride=(stride, ))

    x = F.reshape(x_cos - x_sin, (x_cos.shape[0], x_cos.shape[2]))

    if center:
        x = x[:, fft_size // 2:-fft_size // 2]

    return x
示例#9
0
    def __call__(self, x, test=False):
        '''
        Build the forward graph that separates a mixture into four sources
        (bass, drums, vocals, other).

        Input: (nb_samples, nb_channels, nb_timesteps)
            or (nb_frames, nb_samples, nb_channels, nb_bins)
            NOTE(review): `x` is always passed through `get_stft` first;
            confirm that the second layout is really accepted.

        Args:
            x: Mixture signal (see the layouts above).
            test (bool): Stored in `self.test`. It is not read again in this
                method; presumably `fc_bn` / `lstm` consume it -- TODO confirm.

        Outputs: a tuple of
            - the mixture spectrogram returned by `get_spectogram`
              (not cropped to `self.nb_bins`),
            - the four ReLU masks (bass, drums, vocals, other) concatenated
              along axis 2,
            - the predicted time-domain sources reshaped to
              (4, nb_samples, nb_channels, -1), or None when
              `self.is_predict` is set.
        '''
        self.test = test
        # Complex STFT of the mixture; the phase is kept for resynthesis and
        # only the spectrogram is fed to the network.
        fft_real, fft_imag = get_stft(x, n_fft=self.n_fft, n_hop=self.n_hop)
        x_theta = F.atan2(fft_imag, fft_real)
        x = get_spectogram(fft_real, fft_imag, mono=(self.nb_channels == 1))

        nb_frames, nb_samples, nb_channels, nb_bins = x.shape

        # Keep the full spectrogram for masking; the network input is cropped
        # to the first self.nb_bins entries of the last axis.
        mix_spec = F.identity(x)
        x = x[..., :self.nb_bins]

        # clone
        x_bass = F.identity(x)
        x_drums = F.identity(x)
        x_vocals = F.identity(x)
        x_other = F.identity(x)

        # shift and scale input to mean=0 std=1 (across all bins)
        # The per-source input_mean_* is added, then input_scale_* multiplied.
        x_bass += self.input_mean_bass
        x_drums += self.input_mean_drums
        x_vocals += self.input_mean_vocals
        x_other += self.input_mean_other

        x_bass *= self.input_scale_bass
        x_drums *= self.input_scale_drums
        x_vocals *= self.input_scale_vocals
        x_other *= self.input_scale_other

        # encode and normalize every instance in a batch
        x_bass = self.fc_bn(x_bass,
                            self.hidden_size,
                            "fc1_bass",
                            activation='tanh')
        x_drums = self.fc_bn(x_drums,
                             self.hidden_size,
                             "fc1_drums",
                             activation='tanh')
        x_vocals = self.fc_bn(x_vocals,
                              self.hidden_size,
                              "fc1_vocals",
                              activation='tanh')
        x_other = self.fc_bn(x_other,
                             self.hidden_size,
                             "fc1_other",
                             activation='tanh')

        # Average the sources
        cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # apply 3-layers of stacked LSTM
        # Every source-specific LSTM receives the same averaged encoding.
        lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass")
        lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums")
        lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals")
        lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other")

        # lstm skip connection
        # Each source's fc1 output is concatenated with its own LSTM output.
        x_bass = F.concatenate(x_bass, lstm_out_bass)
        x_drums = F.concatenate(x_drums, lstm_out_drums)
        x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
        x_other = F.concatenate(x_other, lstm_out_other)

        # Second averaging across the four sources.
        cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

        # first dense stage + batch norm
        # All four branches take the same averaged tensor as input.
        x_bass = self.fc_bn(cross_2,
                            self.hidden_size,
                            "fc2_bass",
                            activation='relu')
        x_drums = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_drums",
                             activation='relu')
        x_vocals = self.fc_bn(cross_2,
                              self.hidden_size,
                              "fc2_vocals",
                              activation='relu')
        x_other = self.fc_bn(cross_2,
                             self.hidden_size,
                             "fc2_other",
                             activation='relu')

        # second dense stage + batch norm
        # NOTE(review): the output size uses nb_bins of the uncropped
        # spectrogram, while the reshape below uses self.nb_output_bins;
        # the two presumably have to be equal -- confirm.
        x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass")
        x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums")
        x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals")
        x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other")

        # reshape back to original dim
        x_bass = F.reshape(
            x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_drums = F.reshape(
            x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_vocals = F.reshape(
            x_vocals,
            (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
        x_other = F.reshape(
            x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

        # apply output scale and shift
        x_bass *= self.output_scale_bass
        x_drums *= self.output_scale_drums
        x_vocals *= self.output_scale_vocals
        x_other *= self.output_scale_other

        x_bass += self.output_mean_bass
        x_drums += self.output_mean_drums
        x_vocals += self.output_mean_vocals
        x_other += self.output_mean_other

        # since our output is non-negative, we can apply RELU
        mask_bass = F.relu(x_bass)
        mask_drums = F.relu(x_drums)
        mask_vocals = F.relu(x_vocals)
        mask_other = F.relu(x_other)

        # (Frames, Bsize, Channels, Fbins)
        # Per-source estimate: the mask applied to the mixture spectrogram.
        x_bass = mask_bass * mix_spec
        x_drums = mask_drums * mix_spec
        x_vocals = mask_vocals * mix_spec
        x_other = mask_other * mix_spec

        if not self.is_predict:
            # Resynthesise time-domain signals from the masked spectrograms
            # and the mixture phase.
            tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
            # (4(sources), Frames, Bsize(16), 2(channels), Fbins) ==> (4, Bsize, Channels, Fbins, Frames)
            tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
            pred_r, pred_i = [], []
            for i in range(tmp.shape[0]):
                # Real / imaginary parts: estimate * cos / sin of the mixture
                # phase computed above with atan2.
                pred_r.append(tmp[i] * F.cos(x_theta))
                pred_i.append(tmp[i] * F.sin(x_theta))
            pred_r = F.stack(*pred_r, axis=0)
            pred_i = F.stack(*pred_i, axis=0)
            # Fold sources, samples and channels into one leading axis for
            # the ISTFT call.
            pred_r = F.reshape(
                pred_r,
                (4 * nb_samples * nb_channels, self.nb_output_bins, nb_frames))
            pred_i = F.reshape(
                pred_i,
                (4 * nb_samples * nb_channels, self.nb_output_bins, nb_frames))
            # self.n_fft is passed as both window size and FFT size,
            # self.n_hop as the stride.
            pred = F.istft(pred_r,
                           pred_i,
                           self.n_fft,
                           self.n_hop,
                           self.n_fft,
                           window_type='hanning',
                           center=True)
            # Unfold the leading axis again; the last axis is the time axis.
            pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))
        else:
            # No time-domain prediction is built when self.is_predict is set.
            pred = None

        return mix_spec, F.concatenate(mask_bass,
                                       mask_drums,
                                       mask_vocals,
                                       mask_other,
                                       axis=2), pred