예제 #1
0
def test_criterion_behavior(training):
    """Check how a test-only criterion behaves in train vs. eval mode.

    The model is built with ``SISNRLoss(only_for_test=True)`` wrapped in a
    ``PITSolver``.  In training mode the forward pass is expected to raise
    ``AttributeError``; in evaluation mode it is expected to return a
    ``(loss, stats, weight)`` triple.

    Args:
        training: Truthy to run the model in train mode, falsy for eval mode.
    """
    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]

    test_only_criterion = SISNRLoss(only_for_test=True)
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=test_only_criterion)],
    )

    # Select the mode switch first, then apply it.
    switch_mode = enh_model.train if training else enh_model.eval
    switch_mode()

    # Reference signals are passed as speech_ref1, speech_ref2 (1-based keys).
    batch = {"speech_mix": mixture, "speech_mix_lengths": mixture_lengths}
    for spk_idx, reference in enumerate(references):
        batch["speech_ref{}".format(spk_idx + 1)] = reference

    if not training:
        loss, stats, weight = enh_model(**batch)
        return

    with pytest.raises(AttributeError):
        loss, stats, weight = enh_model(**batch)
예제 #2
0
def test_criterion_behavior_noise(encoder, decoder, separator):
    """Run a training-mode forward pass with a noise-loss criterion.

    The model uses ``SISNRLoss(is_noise_loss=True)`` wrapped in a
    ``PITSolver`` and is fed a mixture, two speech references and one noise
    reference.  The test passes when the forward call returns a
    ``(loss, stats, weight)`` triple without raising.

    Args:
        encoder: Encoder module handed to the enhancement model.
        decoder: Decoder module handed to the enhancement model.
        separator: Separator module handed to the enhancement model.
    """
    separator_is_dccrn = isinstance(separator, (DCCRNSeparator, DC_CRNSeparator))
    if separator_is_dccrn and not isinstance(encoder, STFTEncoder):
        # DCCRNSeparator and DC_CRNSeparator only work for complex spectrum
        # features, so this combination is not exercised.
        return

    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    noise = torch.randn(2, 300)

    noise_criterion = SISNRLoss(is_noise_loss=True)
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=noise_criterion)],
    )
    enh_model.train()

    # Key order: mixture, lengths, speech_ref1..2, then the noise reference.
    batch = {"speech_mix": mixture, "speech_mix_lengths": mixture_lengths}
    for spk_idx, reference in enumerate(references):
        batch["speech_ref{}".format(spk_idx + 1)] = reference
    batch["noise_ref1"] = noise

    loss, stats, weight = enh_model(**batch)
예제 #3
0
def test_criterion_behavior_dereverb(loss_type, num_spk):
    """Run a training-mode forward pass with a dereverberation criterion.

    A ``NeuralBeamformer`` with WPE and beamforming enabled is used as the
    separator.  When ``loss_type`` is ``"mask_mse"`` the criterion is a
    mask-based ``FrequencyDomainMSE``; otherwise it is ``SISNRLoss``.  Both
    are created with ``is_dereverb_loss=True``.  The test passes when the
    forward call returns a ``(loss, stats, weight)`` triple without raising.

    Args:
        loss_type: Loss type forwarded to the beamformer; also selects the
            criterion as described above.
        num_spk: Number of speakers, i.e. number of speech references.
    """
    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float() for _ in range(num_spk)]
    # One dereverb reference is generated per speaker, but only the first
    # one is passed to the model below.
    dereverb_references = [torch.randn(2, 300).float() for _ in range(num_spk)]

    beamformer_conf = dict(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )
    beamformer = NeuralBeamformer(**beamformer_conf)

    if loss_type != "mask_mse":
        criterion = SISNRLoss(is_dereverb_loss=True)
    else:
        criterion = FrequencyDomainMSE(
            compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
        )
    loss_wrapper = PITSolver(criterion=criterion)

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[loss_wrapper],
    )
    enh_model.train()

    # Key order: mixture, lengths, speech_ref1..N, then the dereverb reference.
    batch = {"speech_mix": mixture, "speech_mix_lengths": mixture_lengths}
    for spk_idx, reference in enumerate(references):
        batch["speech_ref{}".format(spk_idx + 1)] = reference
    batch["dereverb_ref1"] = dereverb_references[0]

    loss, stats, weight = enh_model(**batch)
예제 #4
0
import logging
from typing import List
from typing import Union

from mir_eval.separation import bss_eval_sources
import numpy as np
from pystoi import stoi
import torch
from typeguard import check_argument_types

from espnet.utils.cli_utils import get_commandline_args
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.fileio.sound_scp import SoundScpReader
from espnet2.utils import config_argparse

# SISNRLoss criterion instance created once at import time with its default
# arguments.
# NOTE(review): presumably reused by the scoring code below — its usage is
# not visible in this excerpt; confirm.
si_snr_loss = SISNRLoss()


def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,