def test_criterion_behavior(training):
    """Exercise a criterion built with ``only_for_test=True`` in both modes.

    The model is built with a ``PITSolver`` wrapping
    ``SISNRLoss(only_for_test=True)``.  When ``training`` is truthy the model
    is put in train mode and the forward call is expected to raise
    ``AttributeError``; otherwise the model is put in eval mode and the
    forward call must return a 3-tuple.

    Args:
        training: Whether to run the forward pass in train mode.
    """
    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]

    test_only_solver = PITSolver(criterion=SISNRLoss(only_for_test=True))
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[test_only_solver],
    )

    # Keys follow the "speech_ref1", "speech_ref2", ... naming (1-based).
    forward_args = {
        "speech_mix": mixture,
        "speech_mix_lengths": mixture_lengths,
    }
    for spk_idx, reference in enumerate(references, start=1):
        forward_args["speech_ref{}".format(spk_idx)] = reference

    if not training:
        # Eval mode: the forward pass is expected to succeed.
        enh_model.eval()
        loss, stats, weight = enh_model(**forward_args)
        return

    # Train mode: the forward pass is expected to fail with AttributeError.
    enh_model.train()
    with pytest.raises(AttributeError):
        loss, stats, weight = enh_model(**forward_args)
def test_criterion_behavior_noise(encoder, decoder, separator):
    """Run a train-mode forward pass with a noise reference and a noise loss.

    The model uses a ``PITSolver`` wrapping ``SISNRLoss(is_noise_loss=True)``
    and receives two speech references plus ``noise_ref1``.  The test only
    checks that the forward call completes and returns a 3-tuple.

    Args:
        encoder: Encoder module handed to the enhancement model.
        decoder: Decoder module handed to the enhancement model.
        separator: Separator module handed to the enhancement model.
    """
    uses_stft_encoder = isinstance(encoder, STFTEncoder)
    needs_complex_input = isinstance(separator, (DCCRNSeparator, DC_CRNSeparator))
    if needs_complex_input and not uses_stft_encoder:
        # DCCRNSeparator and DC_CRNSeparator only work for complex spectrum
        # features, so this combination is skipped.
        return

    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    noise = torch.randn(2, 300)

    noise_solver = PITSolver(criterion=SISNRLoss(is_noise_loss=True))
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[noise_solver],
    )
    enh_model.train()

    # Keys follow the "speech_ref1", "speech_ref2", ... naming (1-based);
    # the noise reference is appended last.
    forward_args = {
        "speech_mix": mixture,
        "speech_mix_lengths": mixture_lengths,
    }
    for spk_idx, reference in enumerate(references, start=1):
        forward_args["speech_ref{}".format(spk_idx)] = reference
    forward_args["noise_ref1"] = noise

    loss, stats, weight = enh_model(**forward_args)
def test_criterion_behavior_dereverb(loss_type, num_spk):
    """Run a train-mode forward pass with a dereverberation loss.

    A ``NeuralBeamformer`` with WPE enabled is used as the separator.  The
    criterion is ``FrequencyDomainMSE`` (computed on PSM masks) when
    ``loss_type == "mask_mse"`` and ``SISNRLoss`` otherwise, both created with
    ``is_dereverb_loss=True``.  The test only checks that the forward call
    completes and returns a 3-tuple.

    Args:
        loss_type: Loss type forwarded to ``NeuralBeamformer``; also selects
            the criterion.
        num_spk: Number of speakers, i.e. number of speech references.
    """
    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_references = [torch.randn(2, 300).float() for _ in range(num_spk)]

    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        # WPE (dereverberation) settings
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        # beamformer settings
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )

    criterion = (
        FrequencyDomainMSE(
            compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
        )
        if loss_type == "mask_mse"
        else SISNRLoss(is_dereverb_loss=True)
    )
    dereverb_solver = PITSolver(criterion=criterion)

    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[dereverb_solver],
    )
    enh_model.train()

    # Keys follow the "speech_ref1", "speech_ref2", ... naming (1-based).
    forward_args = {
        "speech_mix": mixture,
        "speech_mix_lengths": mixture_lengths,
    }
    for spk_idx, reference in enumerate(references, start=1):
        forward_args["speech_ref{}".format(spk_idx)] = reference
    # Only the first dereverberation reference is passed to the model.
    forward_args["dereverb_ref1"] = dereverb_references[0]

    loss, stats, weight = enh_model(**forward_args)
from typing import List from typing import Union from mir_eval.separation import bss_eval_sources import numpy as np from pystoi import stoi import torch from typeguard import check_argument_types from espnet.utils.cli_utils import get_commandline_args from espnet2.enh.loss.criterions.time_domain import SISNRLoss from espnet2.fileio.datadir_writer import DatadirWriter from espnet2.fileio.sound_scp import SoundScpReader from espnet2.utils import config_argparse si_snr_loss = SISNRLoss() def scoring( output_dir: str, dtype: str, log_level: Union[int, str], key_file: str, ref_scp: List[str], inf_scp: List[str], ref_channel: int, ): assert check_argument_types() logging.basicConfig( level=log_level,