# Imports and shared test utilities. Module paths follow the espnet2 package
# layout; `random_speech`, `stft_encoder`, and `stft_decoder` are stand-ins
# (assumed shapes/settings) for the fixtures defined in the full test file.
import pytest
import torch
from packaging.version import parse as V

from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.pit_solver import PITSolver
from espnet2.enh.separator.neural_beamformer import NeuralBeamformer

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

# Stand-in for the fixed 2-utterance, 16-sample, 2-channel test signal used
# by the original file (the exact values are not reproduced here).
random_speech = torch.randn(2, 16, 2)

# Module-level encoder/decoder used by test_criterion_behavior_dereverb;
# n_fft=32 matches the beamformer's input_dim of 17 (= 32 // 2 + 1).
stft_encoder = STFTEncoder(n_fft=32, hop_length=16)
stft_decoder = STFTDecoder(n_fft=32, hop_length=16)


# The parameter grid is an illustrative assumption (the original decorators
# are not shown in this excerpt); it covers the single-/multi-channel and
# single-/multi-source WPE branches exercised below.
@pytest.mark.parametrize("ch", [1, 2])
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("multi_source_wpe", [True, False])
@pytest.mark.parametrize("use_dnn_mask_for_wpe", [True, False])
def test_neural_beamformer_wpe_output(
    ch, num_spk, multi_source_wpe, use_dnn_mask_for_wpe
):
    torch.random.manual_seed(0)
    inputs = torch.randn(2, 16, ch) if ch > 1 else torch.randn(2, 16)
    inputs = inputs.float()
    ilens = torch.LongTensor([16, 12])
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        use_wpe=True,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        multi_source_wpe=multi_source_wpe,
        wlayers=2,
        wunits=2,
        wprojs=2,
        taps=5,
        delay=3,
        use_beamformer=False,
    )
    model.eval()
    input_spectrum, flens = stft(inputs, ilens)
    specs, _, others = model(input_spectrum, flens)
    assert isinstance(specs, list)
    if not use_dnn_mask_for_wpe or multi_source_wpe:
        assert len(specs) == 1
    else:
        assert len(specs) == num_spk
    assert specs[0].shape == input_spectrum.shape
    assert specs[0].dtype == torch.float
    assert isinstance(others, dict)
    if use_dnn_mask_for_wpe:
        assert "mask_dereverb1" in others, others.keys()
        assert others["mask_dereverb1"].shape == specs[0].shape
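
# Illustrative parametrization (assumed; the original decorator grid is not
# shown in this excerpt). `beamformer_type` lists only a subset of the types
# NeuralBeamformer accepts.
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("use_noise_mask", [True, False])
@pytest.mark.parametrize("beamformer_type", ["mvdr_souden", "lcmv"])
@pytest.mark.parametrize("diagonal_loading", [True, False])
@pytest.mark.parametrize("mask_flooring", [True, False])
@pytest.mark.parametrize("use_torch_solver", [True, False])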
def test_neural_beamformer_bf_output(
    num_spk,
    use_noise_mask,
    beamformer_type,
    diagonal_loading,
    mask_flooring,
    use_torch_solver,
):
    if num_spk == 1 and beamformer_type in (
        "lcmv",
        "lcmp",
        "wlcmp",
        "mvdr_tfs",
        "mvdr_tfs_souden",
    ):
        # these beamformer types only support multiple-source cases
        return
    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        taps=2,
        delay=3,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        diagonal_loading=diagonal_loading,
        mask_flooring=mask_flooring,
        use_torch_solver=use_torch_solver,
    )
    model.eval()
    input_spectrum, flens = stft(inputs, ilens)
    specs, _, others = model(input_spectrum, flens)
    assert isinstance(others, dict)
    if use_noise_mask:
        assert "mask_noise1" in others
        assert others["mask_noise1"].shape == others["mask_spk1"].shape
    assert isinstance(specs, list)
    assert len(specs) == num_spk
    for n in range(1, num_spk + 1):
        assert "mask_spk{}".format(n) in others, others.keys()
        assert others["mask_spk{}".format(n)].shape[-2] == ch
        assert specs[n - 1].shape == others["mask_spk{}".format(n)][..., 0, :].shape
        assert specs[n - 1].shape == input_spectrum[..., 0, :].shape
        if is_torch_1_9_plus and torch.is_complex(specs[n - 1]):
            assert specs[n - 1].dtype == torch.complex64
        else:
            assert specs[n - 1].dtype == torch.float
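
# Illustrative parametrization (assumed) for the reduced variant below, which
# leaves the solver options at their defaults. Souden-type beamformers are
# chosen because they also support the single-source case.
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("use_noise_mask", [True, False])
@pytest.mark.parametrize("beamformer_type", ["mvdr_souden", "mpdr_souden"])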
def test_neural_beamformer_bf_output_default_options(
    num_spk, use_noise_mask, beamformer_type
):
    # Same checks as the test above but with the default solver options; the
    # distinct name keeps both variants collectable by pytest (the original
    # reused the name, so this one silently shadowed the test above).
    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        use_wpe=False,
        taps=2,
        delay=3,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
    )
    model.eval()
    input_spectrum, flens = stft(inputs, ilens)
    specs, _, others = model(input_spectrum, flens)
    assert isinstance(others, dict)
    if use_noise_mask:
        assert "mask_noise1" in others
        assert others["mask_noise1"].shape == others["mask_spk1"].shape
    assert isinstance(specs, list)
    assert len(specs) == num_spk
    for n in range(1, num_spk + 1):
        assert "mask_spk{}".format(n) in others, others.keys()
        assert others["mask_spk{}".format(n)].shape[-2] == ch
        assert specs[n - 1].shape == others["mask_spk{}".format(n)][..., 0, :].shape
        assert specs[n - 1].shape == input_spectrum[..., 0, :].shape
        assert specs[n - 1].dtype == torch.float
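
# Illustrative parametrization (assumed): "mask_mse" exercises the PSM branch
# below; any other value falls through to the SI-SNR branch.
@pytest.mark.parametrize("loss_type", ["mask_mse", "si_snr"])
@pytest.mark.parametrize("num_spk", [1, 2])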
def test_criterion_behavior_dereverb(loss_type, num_spk):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_ref = [torch.randn(2, 300).float() for _ in range(num_spk)]
    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )
    if loss_type == "mask_mse":
        loss_wrapper = PITSolver(
            criterion=FrequencyDomainMSE(
                compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
            )
        )
    else:
        loss_wrapper = PITSolver(criterion=SISNRLoss(is_dereverb_loss=True))
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[loss_wrapper],
    )
    enh_model.train()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "dereverb_ref1": dereverb_ref[0],
    }
    loss, stats, weight = enh_model(**kwargs)
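
# Illustrative parametrization (assumed). The grid in the full test file is
# considerably larger; this subset still reaches each skip branch below.
@pytest.mark.parametrize("n_fft, win_length, hop_length", [(8, None, 2)])
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("loss_type", ["mask_mse", "spectrum"])
@pytest.mark.parametrize("use_wpe", [True, False])
@pytest.mark.parametrize("wnet_type, wlayers, wunits, wprojs", [("blstmp", 2, 2, 2)])
@pytest.mark.parametrize("taps, delay", [(2, 3)])
@pytest.mark.parametrize("use_dnn_mask_for_wpe", [True, False])
@pytest.mark.parametrize("multi_source_wpe", [True, False])
@pytest.mark.parametrize("use_beamformer", [True, False])
@pytest.mark.parametrize(
    "bnet_type, blayers, bunits, bprojs, badim", [("blstmp", 2, 2, 2, 2)]
)
@pytest.mark.parametrize("ref_channel", [0])
@pytest.mark.parametrize("use_noise_mask", [True, False])
@pytest.mark.parametrize("bnonlinear", ["sigmoid", "relu"])
@pytest.mark.parametrize("beamformer_type", ["mvdr_souden", "lcmv"])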
def test_neural_beamformer_forward_backward(
    n_fft,
    win_length,
    hop_length,
    num_spk,
    loss_type,
    use_wpe,
    wnet_type,
    wlayers,
    wunits,
    wprojs,
    taps,
    delay,
    use_dnn_mask_for_wpe,
    multi_source_wpe,
    use_beamformer,
    bnet_type,
    blayers,
    bunits,
    bprojs,
    badim,
    ref_channel,
    use_noise_mask,
    bnonlinear,
    beamformer_type,
):
    # Skip some cases
    if num_spk > 1 and use_wpe and use_beamformer:
        if not multi_source_wpe:
            # single-source WPE is not supported with a beamformer
            # in multi-speaker cases
            return
    elif num_spk == 1:
        if multi_source_wpe:
            # `multi_source_wpe` has no effect when num_spk == 1
            return
        elif beamformer_type in (
            "lcmv",
            "lcmp",
            "wlcmp",
            "mvdr_tfs",
            "mvdr_tfs_souden",
        ):
            # these beamformer types only support multiple-source cases
            return
    if bnonlinear != "sigmoid" and (
        beamformer_type != "mvdr_souden" or multi_source_wpe
    ):
        # only test different nonlinear layers with MVDR_Souden
        return

    # ensure reproducibility and invertibility in the matrix inverse computation
    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=n_fft, win_length=win_length, hop_length=hop_length)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=num_spk,
        loss_type=loss_type,
        use_wpe=use_wpe,
        wnet_type=wnet_type,
        wlayers=wlayers,
        wunits=wunits,
        wprojs=wprojs,
        taps=taps,
        delay=delay,
        use_dnn_mask_for_wpe=use_dnn_mask_for_wpe,
        use_beamformer=use_beamformer,
        bnet_type=bnet_type,
        blayers=blayers,
        bunits=bunits,
        bprojs=bprojs,
        badim=badim,
        ref_channel=ref_channel,
        use_noise_mask=use_noise_mask,
        beamformer_type=beamformer_type,
        rtf_iterations=2,
        shared_power=True,
    )
    model.train()
    inputs = random_speech[..., :2].float()
    ilens = torch.LongTensor([16, 12])
    input_spectrum, flens = stft(inputs, ilens)
    est_speech, flens, others = model(input_spectrum, flens)
    if loss_type.startswith("mask"):
        assert est_speech is None
        loss = sum([abs(m).mean() for m in others.values()])
    else:
        loss = sum([abs(est).mean() for est in est_speech])
    loss.backward()

def test_beamformer_net_invalid_loss_type():
    with pytest.raises(ValueError):
        NeuralBeamformer(10, loss_type="fff")

def test_beamformer_net_invalid_bf_type():
    with pytest.raises(ValueError):
        NeuralBeamformer(10, use_beamformer=True, beamformer_type="fff")
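
# Illustrative parametrization (assumed). The loss types mirror the branches
# checked inside the test; "IBM"/"PSM" are two of the supported mask types.
@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("mask_type", ["IBM", "PSM"])
@pytest.mark.parametrize("loss_type", ["mask_mse", "magnitude", "spectrum", "si_snr"])
@pytest.mark.parametrize("num_spk", [1, 2])
@pytest.mark.parametrize("use_noise_mask", [True, False])
@pytest.mark.parametrize("stft_consistency", [True, False])
@pytest.mark.parametrize("use_builtin_complex", [True, False])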
def test_forward_with_beamformer_net(
    training,
    mask_type,
    loss_type,
    num_spk,
    use_noise_mask,
    stft_consistency,
    use_builtin_complex,
):
    # Skip some testing cases
    if not loss_type.startswith("mask") and mask_type != "IBM":
        # `mask_type` has no effect when `loss_type` is not "mask..."
        return
    if not is_torch_1_9_plus and use_builtin_complex:
        # builtin complex tensors are only exercised with PyTorch 1.9+
        return

    ch = 2
    inputs = random_speech[..., :ch].float()
    ilens = torch.LongTensor([16, 12])
    speech_refs = [torch.randn(2, 16, ch).float() for _ in range(num_spk)]
    noise_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    dereverb_ref1 = torch.randn(2, 16, ch, dtype=torch.float)
    encoder = STFTEncoder(
        n_fft=8, hop_length=2, use_builtin_complex=use_builtin_complex
    )
    decoder = STFTDecoder(n_fft=8, hop_length=2)

    if stft_consistency and loss_type in ("mask_mse", "snr", "si_snr", "ci_sdr"):
        # stft_consistency is not applicable to these loss types
        return

    beamformer = NeuralBeamformer(
        input_dim=5,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=use_noise_mask,
        beamformer_type="mvdr_souden",
    )
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        decoder=decoder,
        separator=beamformer,
        stft_consistency=stft_consistency,
        loss_type=loss_type,
        mask_type=mask_type,
    )
    if training:
        enh_model.train()
    else:
        enh_model.eval()
    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "noise_ref1": noise_ref1,
        "dereverb_ref1": dereverb_ref1,
    }
    loss, stats, weight = enh_model(**kwargs)
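
# A minimal standalone usage sketch (not part of the test suite): run the
# beamformer on a random two-channel mixture and collect the per-speaker
# spectra and predicted masks. All hyperparameters here are illustrative.
def _demo_neural_beamformer_inference():
    torch.random.manual_seed(0)
    stft = STFTEncoder(n_fft=8, hop_length=2)
    model = NeuralBeamformer(
        stft.output_dim,
        num_spk=2,
        use_wpe=False,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        beamformer_type="mvdr_souden",
    )
    model.eval()
    # (batch, samples, channels) waveform mixture
    mixture = torch.randn(1, 16, 2)
    spec, flens = stft(mixture, torch.LongTensor([16]))
    with torch.no_grad():
        specs, flens, others = model(spec, flens)
    # specs: list of per-speaker spectra; others: dict of estimated masks
    return specs, others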