Exemplo n.º 1
0
def test_STFTDecoder_backward(n_fft, win_length, hop_length, window, center,
                              normalized, onesided):
    """Check that a backward pass through ``STFTDecoder`` runs.

    A decoder is built from the given STFT options, fed a random complex
    spectrum whose real and imaginary parts both require grad, and the sum
    of the decoded output is back-propagated.
    """
    stft_options = {
        "n_fft": n_fft,
        "win_length": win_length,
        "hop_length": hop_length,
        "window": window,
        "center": center,
        "normalized": normalized,
        "onesided": onesided,
    }
    decoder = STFTDecoder(**stft_options)

    # Last dimension is n_fft // 2 + 1 for a onesided spectrum, else n_fft.
    spec_shape = (2, 300, (n_fft // 2 + 1) if onesided else n_fft)
    real = torch.rand(*spec_shape, requires_grad=True)
    imag = torch.rand(*spec_shape, requires_grad=True)
    spectrum = ComplexTensor(real, imag)

    lengths = torch.tensor([300 * hop_length, 295 * hop_length],
                           dtype=torch.long)
    decoded, _ = decoder(spectrum, lengths)
    decoded.sum().backward()
Exemplo n.º 2
0
def test_STFTDecoder_invalid_type(n_fft, win_length, hop_length, window,
                                  center, normalized, onesided):
    """Check that ``STFTDecoder`` raises ``TypeError`` for a plain tensor.

    The decoder is called with a real-valued ``torch.Tensor`` instead of a
    ``ComplexTensor``; the call is expected to raise ``TypeError``.
    """
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    # Build the inputs outside of ``pytest.raises`` so that a TypeError
    # raised while constructing them cannot make the test pass by accident.
    real = torch.rand(2,
                      300,
                      n_fft // 2 + 1 if onesided else n_fft,
                      requires_grad=True)
    x_lens = torch.tensor([300 * hop_length, 295 * hop_length],
                          dtype=torch.long)

    # Only the decoder call itself is expected to raise.
    with pytest.raises(TypeError):
        decoder(real, x_lens)
Exemplo n.º 3
0
def test_STFTDecoder_invalid_type(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Check that ``STFTDecoder`` raises ``TypeError`` for a plain tensor.

    The test is skipped when ``is_torch_1_2_plus`` is false. Otherwise the
    decoder is called with a real-valued ``torch.Tensor`` instead of a
    ``ComplexTensor``, and the call is expected to raise ``TypeError``.
    """
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    # Build the inputs outside of ``pytest.raises`` so that a TypeError
    # raised while constructing them cannot make the test pass by accident.
    real = torch.rand(
        2, 300, n_fft // 2 + 1 if onesided else n_fft, requires_grad=True
    )
    x_lens = torch.tensor([300 * hop_length, 295 * hop_length], dtype=torch.long)

    # Only the decoder call itself is expected to raise.
    with pytest.raises(TypeError):
        decoder(real, x_lens)
Exemplo n.º 4
0
def test_STFTDecoder_backward(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Check that a backward pass through ``STFTDecoder`` runs.

    The test is skipped when ``is_torch_1_2_plus`` is false. Otherwise a
    random complex spectrum whose real and imaginary parts require grad is
    decoded, and the sum of the output is back-propagated.
    """
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    stft_options = {
        "n_fft": n_fft,
        "win_length": win_length,
        "hop_length": hop_length,
        "window": window,
        "center": center,
        "normalized": normalized,
        "onesided": onesided,
    }
    decoder = STFTDecoder(**stft_options)

    # Last dimension is n_fft // 2 + 1 for a onesided spectrum, else n_fft.
    spec_shape = (2, 300, (n_fft // 2 + 1) if onesided else n_fft)
    real = torch.rand(*spec_shape, requires_grad=True)
    imag = torch.rand(*spec_shape, requires_grad=True)
    spectrum = ComplexTensor(real, imag)

    lengths = torch.tensor([300 * hop_length, 295 * hop_length], dtype=torch.long)
    decoded, _ = decoder(spectrum, lengths)
    decoded.sum().backward()
Exemplo n.º 5
0
from espnet2.enh.encoder.conv_encoder import ConvEncoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.espnet_enh_s2t_model import ESPnetEnhS2TModel
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.layers.label_aggregation import LabelAggregate

# Shared module-level fixtures for the enhancement part of the tests.

# The STFT encoder and decoder are configured with the same settings.
_ENH_STFT_CONF = {"n_fft": 32, "hop_length": 16}
enh_stft_encoder = STFTEncoder(**_ENH_STFT_CONF)
enh_stft_decoder = STFTDecoder(**_ENH_STFT_CONF)

# NOTE(review): input_dim=17 equals n_fft // 2 + 1 for n_fft=32, presumably
# the number of onesided STFT bins -- confirm against STFTEncoder.
enh_rnn_separator = RNNSeparator(input_dim=17, layer=1, unit=10, num_spk=1)

# SI-SNR criterion wrapped in a fixed-order solver.
si_snr_loss = SISNRLoss()
fix_order_solver = FixedOrderSolver(criterion=si_snr_loss)

default_frontend = DefaultFrontend(
    fs=300,
Exemplo n.º 6
0
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Run one forward pass of the enhancement model with a neural beamformer.

    A two-channel mixture taken from ``random_speech`` is passed, together
    with random speech, noise and dereverberation references, through an
    ``ESPnetEnhancementModel`` built from an STFT encoder/decoder pair and a
    ``NeuralBeamformer`` (WPE + ``mvdr_souden``). The test only checks that
    the forward call completes and returns three values.

    Parameter combinations that are redundant or unsupported return early
    without running anything.
    """
    # Evaluate every skip condition first so that skipped combinations do
    # not allocate tensors or build modules.
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # builtin complex tensors are only exercised on PyTorch 1.9+
        return
    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # `stft_consistency` is not exercised together with these loss types
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    speech_refs = [torch.randn(2, 16, ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    # NOTE(review): input_dim=5 equals n_fft // 2 + 1 for n_fft=8, presumably
    # the number of onesided STFT bins -- confirm against STFTEncoder.
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    # Reference keys are 1-based: speech_ref1, speech_ref2, ...
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{
            "speech_ref{}".format(idx): ref
            for idx, ref in enumerate(speech_refs, start=1)
        },
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    # Unpacking also checks that the model returns exactly three values.
    loss, stats, weight = enh_model(**kwargs)
Exemplo n.º 7
0
# True when the installed torch version is at least 1.9.0.
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")


# STFT fixtures: all three share the same n_fft / hop_length.
_STFT_CONF = {"n_fft": 28, "hop_length": 16}
stft_encoder = STFTEncoder(**_STFT_CONF)
stft_encoder_bultin_complex = STFTEncoder(**_STFT_CONF, use_builtin_complex=True)
stft_decoder = STFTDecoder(**_STFT_CONF)

# Convolutional fixtures: encoder and decoder share the same settings.
# NOTE(review): channel=15 equals n_fft // 2 + 1 for n_fft=28, presumably so
# both encoder types give the same feature size -- confirm.
_CONV_CONF = {"channel": 15, "kernel_size": 32, "stride": 16}
conv_encoder = ConvEncoder(**_CONV_CONF)
conv_decoder = ConvDecoder(**_CONV_CONF)

rnn_separator = RNNSeparator(
Exemplo n.º 8
0
# True when the installed torch version is at least 1.9.0.
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")


# STFT fixtures: all three share the same n_fft / hop_length.
_STFT_CONF = {"n_fft": 16, "hop_length": 8}
stft_encoder = STFTEncoder(**_STFT_CONF)
stft_encoder_bultin_complex = STFTEncoder(**_STFT_CONF, use_builtin_complex=True)
stft_decoder = STFTDecoder(**_STFT_CONF)

# Convolutional fixtures: encoder and decoder share the same settings.
# NOTE(review): channel=9 equals n_fft // 2 + 1 for n_fft=16, presumably so
# both encoder types give the same feature size -- confirm.
_CONV_CONF = {"channel": 9, "kernel_size": 20, "stride": 10}
conv_encoder = ConvEncoder(**_CONV_CONF)
conv_decoder = ConvDecoder(**_CONV_CONF)

rnn_separator = RNNSeparator(