示例#1
0
def test_istft_double_backward(ctx, seed, window_size, stride, fft_size, window_type, center, pad_mode, as_stft_backward):
    """Check second-order gradients of ``F.istft`` via ``backward_function_tester``.

    The ISTFT inputs are the two outputs of ``ref_stft`` applied to a random
    float32 signal whose shape comes from ``create_stft_input_shape``.
    """
    # The plain "cuda" backend is skipped (see skip message); other backends run.
    if ctx.backend[0].split(":")[0] == 'cuda':
        pytest.skip('CUDA Convolution N-D is only supported in CUDNN extension')

    # Non-constant padding is only exercised in STFT-backward mode.
    if not as_stft_backward and pad_mode != "constant":
        pytest.skip(
            '`pad_mode != "constant"` is only for `as_stft_backward == True`')

    from nbla_test_utils import backward_function_tester

    # Build the ISTFT inputs from the reference STFT of a seeded random signal.
    input_rng = np.random.RandomState(seed)
    signal_shape = create_stft_input_shape(window_size)
    signal = input_rng.randn(*signal_shape).astype(np.float32)
    spec_real, spec_imag = ref_stft(signal, window_size, stride,
                                    fft_size, window_type, center, pad_mode, False)

    # Outside STFT-backward mode, configurations violating NOLA are skipped.
    if not as_stft_backward and is_nola_violation(
            window_type, window_size, stride, fft_size, signal_shape[1], center):
        pytest.skip('NOLA condition violation.')

    # The tester gets its own freshly seeded generator, independent of the
    # one consumed while generating the inputs above.
    tester_rng = np.random.RandomState(seed)
    backward_function_tester(tester_rng, F.istft,
                             inputs=[spec_real, spec_imag],
                             func_args=[window_size, stride, fft_size,
                                        window_type, center, pad_mode,
                                        as_stft_backward],
                             ctx=ctx,
                             atol_accum=6e-2)
示例#2
0
def test_stft_istft_identity(ctx, window_size, stride, fft_size, window_type,
                             center, pad_mode):
    """Check that ``F.istft(F.stft(x))`` reconstructs ``x``.

    The forward STFT uses the requested ``pad_mode``, and the inverse uses
    ``pad_mode="constant"``. Configurations violating the NOLA condition are
    skipped, and the "cuda" backend is skipped.
    """
    backend = ctx.backend[0].split(":")[0]
    if backend == 'cuda':
        pytest.skip(
            'CUDA Convolution N-D is only supported in CUDNN extension')

    # Draw the input from a fixed-seed generator rather than the global
    # np.random state, so the outcome against the tight 1e-5 tolerances is
    # reproducible from run to run.
    rng = np.random.RandomState(313)
    x_shape = create_stft_input_shape(window_size)
    x = rng.randn(*x_shape)

    # Skip for NOLA condition violation. pytest.skip raises, so no explicit
    # return is needed after it.
    length = x_shape[1]
    if is_nola_violation(window_type, window_size, stride, fft_size, length,
                         center):
        pytest.skip('NOLA condition violation.')

    x = nn.Variable.from_numpy_array(x)
    with nn.context_scope(ctx):
        yr, yi = F.stft(x, window_size, stride, fft_size, window_type, center,
                        pad_mode)
        # NOTE(review): the inverse always uses constant padding, presumably
        # because non-constant pad modes are only supported by istft when
        # as_stft_backward=True — confirm against the F.istft docs.
        z = F.istft(yr,
                    yi,
                    window_size,
                    stride,
                    fft_size,
                    window_type,
                    center,
                    pad_mode="constant")
    z.forward()

    # assert_allclose(actual=x.d, desired=z.d) checks
    # |x.d - z.d| <= atol + rtol * |z.d|, the same test as
    # np.allclose(x.d, z.d). Unlike a bare assert, it reports the mismatching
    # elements on failure. equal_nan=False matches np.allclose's default.
    np.testing.assert_allclose(x.d, z.d, atol=1e-5, rtol=1e-5,
                               equal_nan=False)
示例#3
0
def test_istft_forward_backward(ctx, seed, window_size, stride, fft_size, window_type, center, pad_mode, as_stft_backward):
    """Compare ``F.istft`` forward/backward against ``ref_istft``.

    The inputs are the two outputs of ``ref_stft`` on a seeded random signal.
    Outside STFT-backward mode, configurations that violate the NOLA condition
    are handed to ``check_nola_violation``, and the test returns without
    running ``function_tester``.
    """
    backend_name = ctx.backend[0].split(":")[0]
    if backend_name == 'cuda':
        pytest.skip('CUDA Convolution N-D is only supported in CUDNN extension')

    # Non-constant padding is only exercised in STFT-backward mode.
    if not as_stft_backward and pad_mode != "constant":
        pytest.skip(
            '`pad_mode != "constant"` is only for `as_stft_backward == True`')

    # On the cudnn backend, the implementation is looked up under a different name.
    if backend_name == 'cudnn':
        func_name = "ISTFTCuda"
    else:
        func_name = "ISTFT"

    from nbla_test_utils import function_tester

    # This single generator is used both to build the signal and, afterwards,
    # by function_tester.
    rng = np.random.RandomState(seed)
    signal_shape = create_stft_input_shape(window_size)
    signal = rng.randn(*signal_shape).astype(np.float32)
    spec_real, spec_imag = ref_stft(signal, window_size, stride,
                                    fft_size, window_type, center, pad_mode, False)

    if not as_stft_backward and is_nola_violation(
            window_type, window_size, stride, fft_size, signal_shape[1], center):
        check_nola_violation(spec_real, spec_imag, window_size, stride,
                             fft_size, window_type, center, pad_mode,
                             as_stft_backward)
        return

    istft_args = [window_size, stride, fft_size, window_type, center,
                  pad_mode, as_stft_backward]
    function_tester(rng, F.istft, ref_istft, [spec_real, spec_imag],
                    func_args=istft_args, ctx=ctx, func_name=func_name,
                    atol_f=1e-5, atol_b=3e-2, dstep=1e-2)