Ejemplo n.º 1
0
def test_dprnn_separator_forward_backward_real(
    input_dim,
    rnn_type,
    bidirectional,
    layer,
    unit,
    dropout,
    num_spk,
    nonlinear,
    segment_size,
):
    """Run a forward and a backward pass through DPRNNSeparator on real input.

    The separator is built from the test arguments (presumably supplied by
    pytest parametrization outside this view). A random batch of two
    sequences with lengths 10 and 8 is fed through it. The test checks that
    one Tensor output is returned per speaker, then back-propagates a scalar
    derived from the first output.
    """
    model = DPRNNSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        bidirectional=bidirectional,
        num_spk=num_spk,
        nonlinear=nonlinear,
        layer=layer,
        unit=unit,
        segment_size=segment_size,
        dropout=dropout,
    )
    # Training mode, so dropout layers (if any) are active in the pass below.
    model.train()

    # Two random sequences of 10 frames; the second is declared 8 frames long
    # so that padded / variable-length input is exercised.
    x = torch.rand(2, 10, input_dim)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, _, _ = model(x, ilens=x_lens)

    # Check the count before indexing, so a wrong length fails with a clear
    # assertion instead of an IndexError.
    assert len(masked) == num_spk
    assert all(isinstance(m, Tensor) for m in masked)

    # Gradients must be able to flow back through the separator.
    masked[0].abs().mean().backward()
Ejemplo n.º 2
0
def test_dprnn_separator_invalid_type():
    """Reject an unknown ``nonlinear`` name.

    Every other argument matches a configuration used elsewhere in these
    tests. Only ``nonlinear="fff"`` is bogus, and construction must raise
    ValueError.
    """
    bad_kwargs = dict(
        input_dim=10,
        rnn_type="rnn",
        layer=2,
        unit=10,
        dropout=0.1,
        num_spk=2,
        nonlinear="fff",
        segment_size=2,
    )
    with pytest.raises(ValueError):
        DPRNNSeparator(**bad_kwargs)
Ejemplo n.º 3
0
def test_dprnn_separator_output():
    """Check the structure of DPRNNSeparator outputs for 1 and 2 speakers.

    For each speaker count, the model in eval mode must return:
    - a list of per-speaker outputs, each shaped like the input;
    - a dict holding a ``mask_spk{n}`` entry with the same shape as the
      n-th output, for every speaker n.
    """
    x = torch.rand(2, 10, 10)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = DPRNNSeparator(
            input_dim=10,
            rnn_type="rnn",
            layer=2,
            unit=10,
            dropout=0.1,
            # Must follow the loop variable. A hard-coded value would test
            # the same model twice and never cover the single-speaker case.
            num_spk=num_spk,
            nonlinear="relu",
            segment_size=2,
        )
        model.eval()
        specs, _, others = model(x, x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        assert len(specs) == num_spk
        for n, spec in enumerate(specs, start=1):
            key = f"mask_spk{n}"
            assert spec.shape == x.shape
            assert key in others
            assert spec.shape == others[key].shape
Ejemplo n.º 4
0
    stride=16,
)

# Small fixture modules shared by the tests in this module. Statement order is
# kept as-is: each constructor may draw from the global torch RNG for weight
# initialisation.
conv_decoder = ConvDecoder(channel=15, kernel_size=32, stride=16)

# NOTE(review): every separator below uses input_dim=15, which presumably
# matches the 15-channel conv encoder/decoder; confirm.
rnn_separator = RNNSeparator(input_dim=15, layer=1, unit=10)

dprnn_separator = DPRNNSeparator(
    input_dim=15,
    layer=1,
    unit=10,
    segment_size=4,
)

tcn_separator = TCNSeparator(
    input_dim=15, layer=2, stack=1, bottleneck_dim=10, hidden_dim=10, kernel=3
)

transformer_separator = TransformerSeparator(
    input_dim=15,
    adim=8,
    aheads=2,
    layers=2,
Ejemplo n.º 5
0
null_decoder = NullDecoder()

# Separator fixtures with tiny layer sizes. Most use input_dim=17 and
# predict_noise=True; DCCRN instead requests a noise mask via use_noise_mask.
# Construction order is preserved because module init may consume torch RNG.
conformer_separator = ConformerSeparator(
    input_dim=17,
    predict_noise=True,
    adim=8,
    aheads=2,
    layers=2,
    linear_units=10,
)

dc_crn_separator = DC_CRNSeparator(
    input_dim=17,
    predict_noise=True,
    input_channels=[2, 2, 4],
)

dccrn_separator = DCCRNSeparator(
    input_dim=17,
    num_spk=1,
    kernel_num=[32, 64, 128],
    use_noise_mask=True,
)

dprnn_separator = DPRNNSeparator(
    input_dim=17,
    predict_noise=True,
    layer=1,
    unit=10,
    segment_size=4,
)

# NOTE(review): input_dim=16 here, unlike the 17 used by the other separators.
# This is presumably deliberate (e.g. a divisibility constraint inside
# DPTNetSeparator); confirm before changing.
dptnet_separator = DPTNetSeparator(
    input_dim=16,
    predict_noise=True,
    layer=1,
    unit=10,
    segment_size=4,
)

rnn_separator = RNNSeparator(
    input_dim=17,
    predict_noise=True,
    layer=1,
    unit=10,
)

svoice_separator = SVoiceSeparator(
    input_dim=17,
    enc_dim=4,
    kernel_size=4,
    hidden_size=4,
    num_spk=2,
    num_layers=2,