Ejemplo n.º 1
0
def test_tcn_separator_invalid_type():
    """Check that TCNSeparator raises ValueError for unknown option strings.

    Two constructions are attempted, each with ``input_dim=10`` and a single
    unrecognised value: first for ``nonlinear``, then for ``norm_type``.
    """
    invalid_options = (
        {"nonlinear": "fff"},
        {"norm_type": "xxx"},
    )
    for bad_kwargs in invalid_options:
        with pytest.raises(ValueError):
            TCNSeparator(input_dim=10, **bad_kwargs)
Ejemplo n.º 2
0
def test_tcn_separator_forward_backward_real(
    input_dim,
    layer,
    num_spk,
    nonlinear,
    stack,
    bottleneck_dim,
    hidden_dim,
    kernel,
    causal,
    norm_type,
):
    """Run a forward and backward pass of TCNSeparator on a random real input.

    The model is built from the given arguments and fed a random tensor of
    shape ``(2, 10, input_dim)`` with lengths ``[10, 8]``. The test checks
    that the first returned element is a ``Tensor``, that the returned
    sequence has ``num_spk`` entries, and that ``backward()`` can be called
    on a scalar derived from the first entry.
    """
    separator_config = dict(
        input_dim=input_dim,
        num_spk=num_spk,
        layer=layer,
        stack=stack,
        bottleneck_dim=bottleneck_dim,
        hidden_dim=hidden_dim,
        kernel=kernel,
        causal=causal,
        norm_type=norm_type,
        nonlinear=nonlinear,
    )
    # Build the model before drawing the random input, as the original did.
    model = TCNSeparator(**separator_config)

    x = torch.rand(2, 10, input_dim)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    # The forward call returns three values; only the first is inspected.
    masked, _, _ = model(x, ilens=x_lens)
    first_output = masked[0]

    assert isinstance(first_output, Tensor)
    assert len(masked) == num_spk

    # Reduce to a scalar so that backward() can be called without arguments.
    scalar = first_output.abs().mean()
    scalar.backward()
Ejemplo n.º 3
0
def test_tcn_separator_output():
    """Check the output structure of TCNSeparator for one and two speakers.

    For each ``num_spk`` in ``{1, 2}`` a model is built in eval mode and run
    on a random ``(2, 10, 10)`` input with lengths ``[10, 8]``. The test
    verifies that:

    * the first return value is a list with ``num_spk`` entries,
    * the third return value is a dict,
    * the dict holds a ``mask_spk<k>`` entry for every speaker, with the same
      shape as the corresponding list entry.

    The original test iterated over ``num_spk`` but never passed it to the
    constructor, so each iteration built a model with the default speaker
    count and the single-speaker case was never exercised.
    """
    x = torch.rand(2, 10, 10)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = TCNSeparator(
            input_dim=10,
            num_spk=num_spk,
            # The layer count also follows the loop variable, as before.
            layer=num_spk,
            stack=2,
            bottleneck_dim=3,
            hidden_dim=3,
            kernel=3,
            causal=False,
        )
        model.eval()
        specs, _, others = model(x, x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        # Same per-speaker count check as the forward/backward test.
        assert len(specs) == num_spk
        for n in range(num_spk):
            mask_key = f"mask_spk{n + 1}"
            assert mask_key in others
            assert specs[n].shape == others[mask_key].shape
Ejemplo n.º 4
0
    kernel_size=32,
    stride=16,
)

# Module-level separator fixtures. Every one is built with input_dim=15.
# Construction order is kept as in the original file.
rnn_separator = RNNSeparator(input_dim=15, layer=1, unit=10)

dprnn_separator = DPRNNSeparator(
    input_dim=15,
    layer=1,
    unit=10,
    segment_size=4,
)

tcn_separator = TCNSeparator(
    input_dim=15,
    layer=2,
    stack=1,
    kernel=3,
    bottleneck_dim=10,
    hidden_dim=10,
)

transformer_separator = TransformerSeparator(
    input_dim=15, adim=8, aheads=2, layers=2, linear_units=10
)


@pytest.mark.parametrize(
    "encoder, decoder",
Ejemplo n.º 5
0
    input_dim=17,
    enc_dim=4,
    kernel_size=4,
    hidden_size=4,
    num_spk=2,
    num_layers=2,
    segment_size=4,
    bidirectional=False,
    input_normalize=False,
)

# Separator fixtures built with input_dim=17 and predict_noise=True.
# Construction order is kept as in the original file.
tcn_separator = TCNSeparator(
    input_dim=17,
    layer=2,
    stack=1,
    kernel=3,
    bottleneck_dim=10,
    hidden_dim=10,
    predict_noise=True,
)

transformer_separator = TransformerSeparator(
    input_dim=17,
    adim=8,
    aheads=2,
    layers=2,
    linear_units=10,
    predict_noise=True,
)

# Loss criteria, each created with its default arguments.
si_snr_loss = SISNRLoss()
tf_mse_loss = FrequencyDomainMSE()
tf_l1_loss = FrequencyDomainL1()

# Both solvers are given the same SI-SNR criterion instance.
pit_wrapper = PITSolver(criterion=si_snr_loss)
multilayer_pit_solver = MultiLayerPITSolver(criterion=si_snr_loss)