Example #1
0
def test_PITSolver_forward(num_spk):
    """Run ``PITSolver`` on references that are the estimates in reverse order.

    The first call uses ``independent_perm=True`` and asserts that the first
    entry of the returned ``"perm"`` equals the reversed speaker indices.
    The second call uses ``independent_perm=False`` and passes that
    permutation back in via ``{"perm": ...}``; it only checks that the call
    completes and yields three return values.
    """
    n_batch = 2
    estimates = [torch.rand(n_batch, 10, 100) for _ in range(num_spk)]
    # The references are the very same tensors, listed in reverse speaker order.
    references = estimates[::-1]

    pit_solver = PITSolver(FrequencyDomainL1(), independent_perm=True)
    _loss, _stats, extras = pit_solver(references, estimates)
    found_perm = extras["perm"]

    # Speaker indices counted down from num_spk - 1 to 0.
    expected_perm = torch.tensor(list(range(num_spk - 1, -1, -1)))
    assert found_perm[0].equal(expected_perm)

    # Same data again, now with independent_perm=False and the permutation
    # supplied by the caller.
    pit_solver = PITSolver(FrequencyDomainL1(), independent_perm=False)
    _loss, _stats, extras = pit_solver(
        references, estimates, {"perm": found_perm}
    )
Example #2
0
def test_PITSolver_forward(num_spk):
    """Smoke-test ``FixedOrderSolver`` with ``FrequencyDomainL1``.

    The references are the estimate tensors in reverse speaker order. Nothing
    is asserted about the values; the test only requires that the solver call
    succeeds and unpacks into three results.
    """
    n_batch = 2
    estimates = [torch.rand(n_batch, 10, 100) for _ in range(num_spk)]
    # Same tensor objects as the estimates, reversed along the speaker list.
    references = estimates[::-1]

    fixed_solver = FixedOrderSolver(FrequencyDomainL1())
    _loss, _stats, _extras = fixed_solver(references, estimates)
Example #3
0
def test_MultiLayerPITSolver_forward_multi_layer(num_spk):
    """Run ``MultiLayerPITSolver`` on two layers of per-speaker estimates.

    The references are the last layer's estimates in reverse speaker order.
    With ``independent_perm=True`` the first entry of the returned ``"perm"``
    must equal the reversed speaker indices. The solver is then rebuilt with
    ``independent_perm=False`` and called with that permutation passed in via
    ``{"perm": ...}``; that call is only checked for completing and returning
    three values.
    """
    n_batch = 2
    n_layers = 2

    # Nested list: one inner list of num_spk tensors per layer.
    layered_estimates = []
    for _layer in range(n_layers):
        layered_estimates.append(
            [torch.rand(n_batch, 10, 100) for _spk in range(num_spk)]
        )

    # Reverse the final layer's speaker list to obtain the references.
    references = layered_estimates[-1][::-1]

    ml_solver = MultiLayerPITSolver(FrequencyDomainL1(), independent_perm=True)
    _loss, _stats, extras = ml_solver(references, layered_estimates)
    found_perm = extras["perm"]

    # Speaker indices counted down from num_spk - 1 to 0.
    expected_perm = torch.tensor(list(range(num_spk - 1, -1, -1)))
    assert found_perm[0].equal(expected_perm)

    # Second pass: independent_perm=False with the permutation provided.
    ml_solver = MultiLayerPITSolver(FrequencyDomainL1(), independent_perm=False)
    _loss, _stats, extras = ml_solver(
        references, layered_estimates, {"perm": found_perm}
    )
Example #4
0
    masking_mode="E",
    use_clstm=True,
    bidirectional=False,
    use_cbn=False,
    kernel_size=5,
    kernel_num=[
        32,
        64,
        128,
    ],
    use_builtin_complex=True,
    use_noise_mask=False,
)
# Loss criterion instances created once at module level
# (presumably shared by the parametrized tests that follow -- confirm).
si_snr_loss = SISNRLoss()
tf_mse_loss = FrequencyDomainMSE()
tf_l1_loss = FrequencyDomainL1()

# Solver wrappers built on the criteria above: the PITSolver takes the
# SISNRLoss instance, the FixedOrderSolver takes the FrequencyDomainMSE one.
pit_wrapper = PITSolver(criterion=si_snr_loss)
fix_order_solver = FixedOrderSolver(criterion=tf_mse_loss)


@pytest.mark.parametrize(
    "encoder, decoder",
    [
        (stft_encoder, stft_decoder),
        (stft_encoder_bultin_complex, stft_decoder),
        (conv_encoder, conv_decoder),
    ],
)
@pytest.mark.parametrize(
    "separator",