def test_PITSolver_forward(num_spk):
    """Exercise PITSolver with references that are the estimates reversed.

    With ``independent_perm=True`` the solver is expected to report, for the
    first batch item, the fully reversed speaker order. The solver is then
    rebuilt with ``independent_perm=False`` and called again with the
    previously found permutation passed in; that second call is only
    checked for completing and returning three values.

    Args:
        num_spk: number of speaker streams to generate.
            NOTE(review): no parametrize decorator or fixture is visible
            here; presumably it is supplied outside this view.
    """
    batch_size = 2
    estimates = [torch.rand(batch_size, 10, 100) for _ in range(num_spk)]
    # Same tensor objects as the estimates, in reverse speaker order.
    references = estimates[::-1]

    criterion = FrequencyDomainL1()
    pit_solver = PITSolver(criterion, independent_perm=True)
    _, _, others = pit_solver(references, estimates)

    perm = others["perm"]
    expected_perm = torch.tensor(list(range(num_spk - 1, -1, -1)))
    assert perm[0].equal(expected_perm)

    # Re-run with independent_perm=False, reusing the permutation found above.
    pit_solver = PITSolver(FrequencyDomainL1(), independent_perm=False)
    _, _, _ = pit_solver(references, estimates, {"perm": perm})
def test_PITSolver_forward(num_spk):
    """Smoke-test FixedOrderSolver on estimates paired with reversed references.

    Only checks that the solver call completes and returns three values;
    nothing about the loss or the returned extras is asserted.

    NOTE(review): the test name mentions PITSolver although the body
    exercises FixedOrderSolver. If another test with this exact name lives
    in the same module, the later definition shadows the earlier one and
    only one of them is collected -- confirm and consider renaming.

    Args:
        num_spk: number of speaker streams to generate.
            NOTE(review): no parametrize decorator or fixture is visible
            here; presumably it is supplied outside this view.
    """
    batch_size = 2
    estimates = [torch.rand(batch_size, 10, 100) for _ in range(num_spk)]
    # Same tensor objects as the estimates, in reverse speaker order.
    references = estimates[::-1]

    fixed_solver = FixedOrderSolver(FrequencyDomainL1())
    _, _, _ = fixed_solver(references, estimates)
def test_MultiLayerPITSolver_forward_multi_layer(num_spk):
    """Exercise MultiLayerPITSolver with per-layer lists of estimates.

    The references are the last layer's estimates in reverse speaker order.
    With ``independent_perm=True`` the solver is expected to report, for the
    first batch item, the fully reversed speaker order. The solver is then
    rebuilt with ``independent_perm=False`` and called again with the
    previously found permutation passed in; that second call is only
    checked for completing and returning three values.

    Args:
        num_spk: number of speaker streams generated per layer.
            NOTE(review): no parametrize decorator or fixture is visible
            here; presumably it is supplied outside this view.
    """
    batch_size = 2
    layer_count = 2

    # Build a list of lists: one inner list of num_spk tensors per layer.
    layered_estimates = []
    for _ in range(layer_count):
        layer = [torch.rand(batch_size, 10, 100) for _ in range(num_spk)]
        layered_estimates.append(layer)

    # References reuse the final layer's tensors, in reverse speaker order.
    references = layered_estimates[-1][::-1]

    ml_solver = MultiLayerPITSolver(FrequencyDomainL1(), independent_perm=True)
    _, _, others = ml_solver(references, layered_estimates)

    perm = others["perm"]
    expected_perm = torch.tensor(list(range(num_spk - 1, -1, -1)))
    assert perm[0].equal(expected_perm)

    # Re-run with independent_perm=False, reusing the permutation found above.
    ml_solver = MultiLayerPITSolver(FrequencyDomainL1(), independent_perm=False)
    _, _, _ = ml_solver(references, layered_estimates, {"perm": perm})
masking_mode="E", use_clstm=True, bidirectional=False, use_cbn=False, kernel_size=5, kernel_num=[ 32, 64, 128, ], use_builtin_complex=True, use_noise_mask=False, ) si_snr_loss = SISNRLoss() tf_mse_loss = FrequencyDomainMSE() tf_l1_loss = FrequencyDomainL1() pit_wrapper = PITSolver(criterion=si_snr_loss) fix_order_solver = FixedOrderSolver(criterion=tf_mse_loss) @pytest.mark.parametrize( "encoder, decoder", [ (stft_encoder, stft_decoder), (stft_encoder_bultin_complex, stft_decoder), (conv_encoder, conv_decoder), ], ) @pytest.mark.parametrize( "separator",