# Example No. 1 (votes: 0)
def test_svoice_separator_forward_backward(
    input_dim,
    enc_dim,
    kernel_size,
    hidden_size,
    num_spk,
    num_layers,
    segment_size,
    bidirectional,
    input_normalize,
):
    """Run a training-mode forward pass through SVoiceSeparator, then backprop.

    The separator is built from the given hyper-parameters and fed a random
    batch of two 800-sample rows whose valid lengths are 400 and 300.

    The test checks two things about the first returned value:

    * its ``[0][0]`` entry is a ``Tensor``;
    * it has ``num_layers`` entries.

    It then calls ``backward`` on the mean of that entry so any failure in
    gradient propagation surfaces as an error.
    """
    # Build the model before drawing the input so the random-number stream
    # is consumed in the same order as before.
    separator = SVoiceSeparator(
        input_dim=input_dim,
        enc_dim=enc_dim,
        kernel_size=kernel_size,
        hidden_size=hidden_size,
        num_spk=num_spk,
        num_layers=num_layers,
        segment_size=segment_size,
        bidirectional=bidirectional,
        input_normalize=input_normalize,
    )
    separator.train()

    batch = torch.rand(2, 800)
    lengths = torch.tensor([400, 300], dtype=torch.long)

    # The separator returns a 3-tuple; only the first item is inspected here.
    outputs, _lengths, _extra = separator(batch, ilens=lengths)
    first_output = outputs[0][0]

    assert isinstance(first_output, Tensor)
    assert len(outputs) == num_layers

    # Backprop through a scalar reduction of the first output.
    first_output.mean().backward()
# Example No. 2 (votes: 0)
def test_svoice_separator_output_eval():
    """Check the eval-mode output structure of SVoiceSeparator.

    For one and two speakers, a freshly built separator is switched to eval
    mode and run on a random (2, 800) batch with lengths 10 and 8. The first
    returned value must be a list, and the shape of its ``[0][0]`` entry must
    equal the shape of a single input row.
    """
    inputs = torch.rand(2, 800)
    input_lengths = torch.tensor([10, 8], dtype=torch.long)

    for speakers in (1, 2):
        separator = SVoiceSeparator(
            input_dim=12,
            enc_dim=8,
            kernel_size=8,
            hidden_size=8,
            num_spk=speakers,
            num_layers=4,
            segment_size=2,
            bidirectional=False,
            input_normalize=False,
        )
        separator.eval()

        outputs, _lengths, _extra = separator(inputs, input_lengths)

        assert isinstance(outputs, list)
        assert inputs[0].shape == outputs[0][0].shape
# Example No. 3 (votes: 0)
                                 unit=10,
                                 segment_size=4)

# Small module-level separator instances, presumably shared by tests elsewhere
# in this file. They are constructed in a fixed order; keep it, since each
# constructor may draw from the global RNG when initialising weights.

# NOTE(review): DPTNetSeparator uses input_dim=16 while the others use 17.
# This is presumably deliberate (e.g. a divisibility constraint inside
# DPTNet); confirm before changing.
dptnet_separator = DPTNetSeparator(
    input_dim=16, layer=1, unit=10, segment_size=4,
)

rnn_separator = RNNSeparator(
    input_dim=17, layer=1, unit=10,
)

svoice_separator = SVoiceSeparator(
    input_dim=17, enc_dim=4, kernel_size=4, hidden_size=4,
    num_spk=2, num_layers=2, segment_size=4,
    bidirectional=False, input_normalize=False,
)

tcn_separator = TCNSeparator(
    input_dim=17, layer=2, stack=1,
    bottleneck_dim=10, hidden_dim=10, kernel=3,
)