Exemple #1
0
def test_neural_beamformer_wpe_output(ch, num_spk, multi_source_wpe,
                                      use_dnn_mask_for_wpe):
    """Check output shapes, dtypes and mask keys of a WPE-only NeuralBeamformer."""
    torch.random.manual_seed(0)
    # Build a (batch, time) or (batch, time, channel) mixture depending on ch.
    if ch > 1:
        mixture = torch.randn(2, 16, ch)
    else:
        mixture = torch.randn(2, 16)
    mixture = mixture.float()
    ilens = torch.LongTensor([16, 12])
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        use_wpe=True,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        multi_source_wpe=multi_source_wpe,
        wlayers=2,
        wunits=2,
        wprojs=2,
        taps=5,
        delay=3,
        use_beamformer=False,
    )
    model.eval()
    input_spectrum, flens = stft(mixture, ilens)
    specs, _, others = model(input_spectrum, flens)
    assert isinstance(specs, list)
    # A single dereverberated stream is expected unless per-speaker WPE masks
    # are estimated (DNN mask enabled and not shared across sources).
    expected_streams = 1 if (not use_dnn_mask_for_wpe or multi_source_wpe) else num_spk
    assert len(specs) == expected_streams
    assert specs[0].shape == input_spectrum.shape
    assert specs[0].dtype == torch.float
    assert isinstance(others, dict)
    if use_dnn_mask_for_wpe:
        assert "mask_dereverb1" in others, others.keys()
        assert others["mask_dereverb1"].shape == specs[0].shape
Exemple #2
0
def test_neural_beamformer_bf_output(
    num_spk,
    use_noise_mask,
    beamformer_type,
    diagonal_loading,
    mask_flooring,
    use_torch_solver,
):
    """Check output shapes, dtypes and estimated masks of a beamformer-only model."""
    # These beamformer variants only support multiple-source cases.
    multi_source_only = (
        "lcmv",
        "lcmp",
        "wlcmp",
        "mvdr_tfs",
        "mvdr_tfs_souden",
    )
    if num_spk == 1 and beamformer_type in multi_source_only:
        return

    ch = 2
    speech_mix = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])

    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        taps=2,
        delay=3,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        diagonal_loading=diagonal_loading,
        mask_flooring=mask_flooring,
        use_torch_solver=use_torch_solver,
    )
    model.eval()
    spectrum, flens = stft(speech_mix, ilens)
    specs, _, others = model(spectrum, flens)
    assert isinstance(others, dict)
    if use_noise_mask:
        assert "mask_noise1" in others
        assert others["mask_noise1"].shape == others["mask_spk1"].shape
    assert isinstance(specs, list)
    assert len(specs) == num_spk
    for idx, spec in enumerate(specs, start=1):
        mask_key = f"mask_spk{idx}"
        assert mask_key in others, others.keys()
        mask = others[mask_key]
        # Masks keep a channel axis at dim -2; output specs are single-channel.
        assert mask.shape[-2] == ch
        assert spec.shape == mask[..., 0, :].shape
        assert spec.shape == spectrum[..., 0, :].shape
        if is_torch_1_9_plus and torch.is_complex(spec):
            assert spec.dtype == torch.complex64
        else:
            assert spec.dtype == torch.float
def test_STFTEncoder_backward(n_fft, win_length, hop_length, window, center,
                              normalized, onesided):
    """Ensure gradients propagate through the STFT encoder."""
    enc = STFTEncoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    wav = torch.rand(2, 32000, requires_grad=True)
    wav_lens = torch.tensor([32000, 30000], dtype=torch.long)
    spec, flens = enc(wav, wav_lens)
    # Reduce to a scalar and backprop; raises if the graph is broken.
    spec.abs().sum().backward()
Exemple #4
0
def test_STFTEncoder_backward(n_fft, win_length, hop_length, window, center,
                              normalized, onesided):
    """Ensure gradients propagate through the STFT encoder (torch>=1.2 only)."""
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    enc = STFTEncoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    wav = torch.rand(2, 32000, requires_grad=True)
    wav_lens = torch.tensor([32000, 30000], dtype=torch.long)
    spec, flens = enc(wav, wav_lens)
    # Reduce to a scalar and backprop; raises if the graph is broken.
    spec.abs().sum().backward()
Exemple #5
0
def test_neural_beamformer_bf_output(num_spk, use_noise_mask, beamformer_type):
    """Check output shapes, dtypes and estimated masks of a beamformer-only model."""
    ch = 2
    speech_mix = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])

    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        taps=2,
        delay=3,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
    )
    model.eval()
    spectrum, flens = stft(speech_mix, ilens)
    specs, _, others = model(spectrum, flens)
    assert isinstance(others, dict)
    if use_noise_mask:
        assert "mask_noise1" in others
        assert others["mask_noise1"].shape == others["mask_spk1"].shape
    assert isinstance(specs, list)
    assert len(specs) == num_spk
    for idx, spec in enumerate(specs, start=1):
        mask_key = f"mask_spk{idx}"
        assert mask_key in others, others.keys()
        mask = others[mask_key]
        # Masks keep a channel axis at dim -2; output specs are single-channel.
        assert mask.shape[-2] == ch
        assert spec.shape == mask[..., 0, :].shape
        assert spec.shape == spectrum[..., 0, :].shape
        assert spec.dtype == torch.float
Exemple #6
0
def test_neural_beamformer_forward_backward(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    loss_type,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    multi_source_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    bnonlinear,
    beamformer_type,
):
    """Smoke-test forward and backward passes of NeuralBeamformer.

    All arguments are pytest parameters selecting the STFT / WPE / beamformer
    configuration; unsupported or redundant combinations return early so the
    parametrized grid stays tractable.
    """
    # Skip some cases
    if num_spk > 1 and use_wpe and use_beamformer:
        if not multi_source_wpe:
            # Single-source WPE is not supported with beamformer in multi-speaker cases
            return
    elif num_spk == 1:
        if multi_source_wpe:
            # When num_spk == 1, `multi_source_wpe` has no effect
            return
        elif beamformer_type in (
            "lcmv",
            "lcmp",
            "wlcmp",
            "mvdr_tfs",
            "mvdr_tfs_souden",
        ):
            # only support multiple-source cases
            return
    if bnonlinear != "sigmoid" and (
        beamformer_type != "mvdr_souden" or multi_source_wpe
    ):
        # only test different nonlinear layers with MVDR_Souden
        return

    # ensures reproducibility and reversibility in the matrix inverse computation
    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=n_fft, win_length=win_length, hop_length=hop_length)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        loss_type=loss_type,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        rtf_iterations=2,
        shared_power=True,
    )

    model.train()
    # Two-channel slice of the shared fixture waveform; lengths per utterance.
    inputs = random_speech[..., :2].float()
    ilens = torch.LongTensor([16, 12])
    input_spectrum, flens = stft(inputs, ilens)
    est_speech, flens, others = model(input_spectrum, flens)
    if loss_type.startswith("mask"):
        # Mask-based losses produce no separated speech; backprop through the
        # estimated masks instead.
        assert est_speech is None
        loss = sum([abs(m).mean() for m in others.values()])
    else:
        loss = sum([abs(est).mean() for est in est_speech])
    loss.backward()
Exemple #7
0
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
    metrics: List[str],
    frame_size: int = 512,
    frame_hop: int = 256,
):
    """Score enhanced ("inf") audio against reference audio and write results.

    Args:
        output_dir: directory for the per-metric score files (DatadirWriter).
        dtype: sample dtype passed to SoundScpReader.
        log_level: logging level for basicConfig.
        key_file: scp-style file whose first column lists utterance IDs.
        ref_scp: one scp file per reference speaker.
        inf_scp: one scp file per enhanced/estimated speaker (same length).
        ref_channel: channel index used when audio is multi-channel.
        metrics: subset of STOI/ESTOI/SNR/SI_SNR/SDR/SAR/SIR/framewise-SNR.
        frame_size: STFT window size for framewise-SNR.
        frame_hop: STFT hop size for framewise-SNR.
    """
    assert check_argument_types()
    # Fail early on unknown metric names.
    for metric in metrics:
        assert metric in (
            "STOI",
            "ESTOI",
            "SNR",
            "SI_SNR",
            "SDR",
            "SAR",
            "SIR",
            "framewise-SNR",
        ), metric

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # One reference scp per estimated source.
    assert len(ref_scp) == len(inf_scp), ref_scp
    num_spk = len(ref_scp)

    # Utterance IDs come from the first column of the key file.
    keys = [
        line.rstrip().split(maxsplit=1)[0]
        for line in open(key_file, encoding="utf-8")
    ]

    ref_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp
    ]
    inf_readers = [
        SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp
    ]

    # get sample rate
    # (SoundScpReader items are (sample_rate, waveform) pairs.)
    fs, _ = ref_readers[0][keys[0]]

    # check keys
    for inf_reader, ref_reader in zip(inf_readers, ref_readers):
        assert inf_reader.keys() == ref_reader.keys()

    # STFT used only for the framewise-SNR metric.
    stft = STFTEncoder(n_fft=frame_size, hop_length=frame_hop)

    do_bss_eval = "SDR" in metrics or "SAR" in metrics or "SIR" in metrics
    with DatadirWriter(output_dir) as writer:
        for key in keys:
            # Stack per-speaker waveforms: (num_spk, T) or (num_spk, T, C).
            ref_audios = [ref_reader[key][1] for ref_reader in ref_readers]
            inf_audios = [inf_reader[key][1] for inf_reader in inf_readers]
            ref = np.array(ref_audios)
            inf = np.array(inf_audios)
            if ref.ndim > inf.ndim:
                # multi-channel reference and single-channel output
                ref = ref[..., ref_channel]
                assert ref.shape == inf.shape, (ref.shape, inf.shape)
            elif ref.ndim < inf.ndim:
                # single-channel reference and multi-channel output
                raise ValueError("Reference must be multi-channel when the "
                                 "network output is multi-channel.")
            elif ref.ndim == inf.ndim == 3:
                # multi-channel reference and output
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            # bss_eval also supplies the speaker permutation used below;
            # it is therefore run for any multi-speaker case, not only SDR/SAR/SIR.
            if do_bss_eval or num_spk > 1:
                sdr, sir, sar, perm = bss_eval_sources(
                    ref, inf, compute_permutation=True)
            else:
                perm = [0]

            ilens = torch.LongTensor([ref.shape[1]])
            # (num_spk, T, F)
            ref_spec, flens = stft(torch.from_numpy(ref), ilens)
            inf_spec, _ = stft(torch.from_numpy(inf), ilens)

            for i in range(num_spk):
                # p: index of the estimated source matched to reference i.
                p = int(perm[i])
                for metric in metrics:
                    name = f"{metric}_spk{i + 1}"
                    if metric == "STOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=False))
                    elif metric == "ESTOI":
                        writer[name][key] = str(
                            stoi(ref[i], inf[p], fs_sig=fs, extended=True))
                    elif metric == "SNR":
                        # Losses are negated to report scores (higher = better).
                        si_snr_score = -float(
                            ESPnetEnhancementModel.snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SI_SNR":
                        si_snr_score = -float(
                            ESPnetEnhancementModel.si_snr_loss(
                                torch.from_numpy(ref[i][None, ...]),
                                torch.from_numpy(inf[p][None, ...]),
                            ))
                        writer[name][key] = str(si_snr_score)
                    elif metric == "SDR":
                        writer[name][key] = str(sdr[i])
                    elif metric == "SAR":
                        writer[name][key] = str(sar[i])
                    elif metric == "SIR":
                        writer[name][key] = str(sir[i])
                    elif metric == "framewise-SNR":
                        framewise_snr = -ESPnetEnhancementModel.snr_loss(
                            ref_spec[i].abs(), inf_spec[i].abs())
                        writer[name][key] = " ".join(
                            map(str, framewise_snr.tolist()))
                    else:
                        raise ValueError("Unsupported metric: %s" % metric)
                    # save permutation assigned script file
                    # NOTE(review): this write sits inside the metric loop, so
                    # it repeats once per metric with the same value — redundant
                    # but harmless.
                    writer[f"wav_spk{i + 1}"][key] = inf_readers[
                        perm[i]].data[key]
Exemple #8
0
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.diar.layers.multi_mask import MultiMask
from espnet2.diar.separator.tcn_separator_nomask import TCNSeparatorNomask
from espnet2.enh.decoder.conv_decoder import ConvDecoder
from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.conv_encoder import ConvEncoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.espnet_enh_s2t_model import ESPnetEnhS2TModel
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.layers.label_aggregation import LabelAggregate

# Shared module-level fixtures reused across the tests below.

# STFT analysis front-end for the enhancement model.
enh_stft_encoder = STFTEncoder(
    n_fft=32,
    hop_length=16,
)

# Matching inverse-STFT back-end (same n_fft/hop as the encoder).
enh_stft_decoder = STFTDecoder(
    n_fft=32,
    hop_length=16,
)

# Small single-speaker RNN separator; input_dim=17 matches the encoder's
# spectral dimension (presumably n_fft//2 + 1 — TODO confirm).
enh_rnn_separator = RNNSeparator(
    input_dim=17,
    layer=1,
    unit=10,
    num_spk=1,
)

# Time-domain SI-SNR criterion used as the enhancement loss.
si_snr_loss = SISNRLoss()
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    """Run a full ESPnetEnhancementModel forward pass with a NeuralBeamformer.

    All arguments are pytest parameters; incompatible combinations return
    early. Only checks that the forward call succeeds and returns
    (loss, stats, weight).
    """
    # Skip some testing cases
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # skip unless the installed PyTorch is new enough for builtin complex
        # tensors (this flag checks for 1.9+)
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    # One reference per speaker, plus single noise / dereverberation refs.
    speech_refs = [torch.randn(2, 16, ch).float() for spk in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # skip this condition
        return

    # WPE + MVDR(Souden) beamformer with small layer sizes for test speed.
    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()

    # Assemble keyword arguments expected by ESPnetEnhancementModel.forward,
    # with speech_ref1..speech_refN expanded from the per-speaker list.
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    loss, stats, weight = enh_model(**kwargs)
from espnet2.enh.decoder.conv_decoder import ConvDecoder
from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.conv_encoder import ConvEncoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.separator.dprnn_separator import DPRNNSeparator
from espnet2.enh.separator.neural_beamformer import NeuralBeamformer
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.enh.separator.tcn_separator import TCNSeparator
from espnet2.enh.separator.transformer_separator import TransformerSeparator

# True when the installed PyTorch supports features gated on 1.9+
# (e.g. builtin complex tensors used by STFTEncoder below).
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")


# Shared module-level fixtures reused across the tests below.

# Default STFT encoder (complex output represented without builtin complex).
stft_encoder = STFTEncoder(
    n_fft=28,
    hop_length=16,
)

# Same encoder but emitting torch builtin complex tensors.
stft_encoder_bultin_complex = STFTEncoder(
    n_fft=28,
    hop_length=16,
    use_builtin_complex=True,
)

# Matching inverse-STFT decoder (same n_fft/hop as the encoders).
stft_decoder = STFTDecoder(
    n_fft=28,
    hop_length=16,
)

conv_encoder = ConvEncoder(
    channel=15,
Exemple #11
0
from espnet2.enh.decoder.conv_decoder import ConvDecoder
from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.conv_encoder import ConvEncoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.separator.dprnn_separator import DPRNNSeparator
from espnet2.enh.separator.neural_beamformer import NeuralBeamformer
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.enh.separator.tcn_separator import TCNSeparator
from espnet2.enh.separator.transformer_separator import TransformerSeparator

# True when the installed PyTorch is at least 1.2 (minimum for the Enh task).
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0")

# Shared module-level fixtures reused across the tests below.

# STFT analysis front-end.
stft_encoder = STFTEncoder(
    n_fft=28,
    hop_length=16,
)

# Matching inverse-STFT decoder (same n_fft/hop as the encoder).
stft_decoder = STFTDecoder(
    n_fft=28,
    hop_length=16,
)

# Learned 1-D convolutional encoder alternative to the STFT front-end.
conv_encoder = ConvEncoder(
    channel=15,
    kernel_size=32,
    stride=16,
)

conv_decoder = ConvDecoder(
    channel=15,