# Example #1
0
def test_dc_crn_separator_invalid_type():
    """An unrecognized ``mode`` string must make the constructor raise ``ValueError``."""
    bad_kwargs = dict(input_dim=17, input_channels=[2, 2, 4], mode="xxx")
    with pytest.raises(ValueError):
        DC_CRNSeparator(**bad_kwargs)
# Example #2
0
def test_dc_crn_separator_invalid_enc_layer():
    """Building the separator with ``enc_layers=1`` is expected to raise ``AssertionError``."""
    bad_kwargs = dict(input_dim=17, input_channels=[2, 2, 4], enc_layers=1)
    with pytest.raises(AssertionError):
        DC_CRNSeparator(**bad_kwargs)
# Example #3
0
def test_dc_crn_separator_forward_backward_complex(
    input_dim,
    num_spk,
    input_channels,
    enc_hid_channels,
    enc_layers,
    glstm_groups,
    glstm_layers,
    glstm_bidirectional,
    glstm_rearrange,
    mode,
):
    """Run a training-mode forward and backward pass on a complex input.

    The parameters are presumably supplied by ``pytest.mark.parametrize``
    decorators that are not visible in this snippet.
    """
    # Convolution geometry held fixed across every parametrized case.
    conv_config = dict(
        enc_kernel_size=(1, 3),
        enc_padding=(0, 1),
        enc_last_kernel_size=(1, 3),
        enc_last_stride=(1, 2),
        enc_last_padding=(0, 1),
        skip_last_kernel_size=(1, 3),
        skip_last_stride=(1, 1),
        skip_last_padding=(0, 1),
    )
    model = DC_CRNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=enc_hid_channels,
        enc_layers=enc_layers,
        glstm_groups=glstm_groups,
        glstm_layers=glstm_layers,
        glstm_bidirectional=glstm_bidirectional,
        glstm_rearrange=glstm_rearrange,
        mode=mode,
        **conv_config,
    )
    model.train()

    batch, frames = 2, 10
    real = torch.rand(batch, frames, input_dim)
    imag = torch.rand(batch, frames, input_dim)
    # Native complex tensors on torch >= 1.9, torch_complex's ComplexTensor otherwise.
    if is_torch_1_9_plus:
        spec = torch.complex(real, imag)
    else:
        spec = ComplexTensor(real, imag)
    spec_lens = torch.tensor([10, 8], dtype=torch.long)

    separated, _, _ = model(spec, ilens=spec_lens)

    assert is_complex(separated[0])
    assert len(separated) == num_spk

    # Gradients must flow back through the complex output.
    separated[0].abs().mean().backward()
# Example #4
0
def test_dc_crn_separator_output():
    """Check the output containers and per-speaker mask entries for 1 and 2 speakers.

    For each speaker count, the separator in eval mode must return a list of
    separated spectra and a dict holding a ``mask_spk<i>`` entry per speaker.
    Each entry must have the same shape as the corresponding spectrum.
    """
    real = torch.rand(2, 10, 17)
    imag = torch.rand(2, 10, 17)
    # Native complex tensors on torch >= 1.9, torch_complex's ComplexTensor otherwise.
    x = torch.complex(real, imag) if is_torch_1_9_plus else ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = DC_CRNSeparator(
            input_dim=17,
            num_spk=num_spk,
            input_channels=[2, 2, 4],
        )
        model.eval()
        specs, _, others = model(x, ilens=x_lens)
        assert isinstance(specs, list)
        assert isinstance(others, dict)
        # Exactly one separated output per speaker (as the sibling tests check).
        assert len(specs) == num_spk
        for n in range(num_spk):
            mask_key = f"mask_spk{n + 1}"
            assert mask_key in others
            assert specs[n].shape == others[mask_key].shape
# Example #5
0
def test_dc_crn_separator_multich_input(
    num_spk,
    input_channels,
    enc_kernel_size,
    enc_padding,
    enc_last_kernel_size,
    enc_last_stride,
    enc_last_padding,
    skip_last_kernel_size,
    skip_last_stride,
    skip_last_padding,
):
    """Run a training-mode forward and backward pass on a multi-channel complex input.

    The input carries ``input_channels[0] // 2`` channels along dim 2. Presumably
    the first encoder channel count covers both the real and imaginary parts
    (TODO confirm). The parameters are presumably supplied by
    ``pytest.mark.parametrize`` decorators that are not visible here.
    """
    freq_bins = 33
    model = DC_CRNSeparator(
        input_dim=freq_bins,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=2,
        enc_layers=3,
        enc_kernel_size=enc_kernel_size,
        enc_padding=enc_padding,
        enc_last_kernel_size=enc_last_kernel_size,
        enc_last_stride=enc_last_stride,
        enc_last_padding=enc_last_padding,
        skip_last_kernel_size=skip_last_kernel_size,
        skip_last_stride=skip_last_stride,
        skip_last_padding=skip_last_padding,
    )
    model.train()

    shape = (2, 10, input_channels[0] // 2, freq_bins)
    # Real part is drawn first, then imaginary, so RNG consumption order is unchanged.
    real, imag = torch.rand(*shape), torch.rand(*shape)
    if is_torch_1_9_plus:
        x = torch.complex(real, imag)
    else:
        x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    outputs, _, _ = model(x, ilens=x_lens)

    assert is_complex(outputs[0])
    assert len(outputs) == num_spk

    outputs[0].abs().mean().backward()
# Example #6
0
# Module-level encoder/decoder/separator instances (presumably shared by the
# tests elsewhere in this file).
# NOTE: keep this construction order. The separator constructors presumably draw
# their initial weights from the global torch RNG.
stft_encoder_bultin_complex = STFTEncoder(
    n_fft=32, hop_length=16, use_builtin_complex=True
)

stft_decoder = STFTDecoder(n_fft=32, hop_length=16)

conv_encoder = ConvEncoder(channel=17, kernel_size=36, stride=18)

conv_decoder = ConvDecoder(channel=17, kernel_size=36, stride=18)

null_encoder = NullEncoder()

null_decoder = NullDecoder()

dc_crn_separator = DC_CRNSeparator(input_dim=17, input_channels=[2, 2, 4])

dccrn_separator = DCCRNSeparator(input_dim=17, num_spk=1, kernel_num=[32, 64, 128])

dprnn_separator = DPRNNSeparator(input_dim=17, layer=1, unit=10, segment_size=4)

# NOTE(review): input_dim is 16 here while the other separators use 17;
# presumably intentional for DPTNet, but confirm.
dptnet_separator = DPTNetSeparator(input_dim=16, layer=1, unit=10, segment_size=4)
# Example #7
0
# Module-level encoder/decoder/separator instances (presumably shared by the
# tests elsewhere in this file). The separators enable a noise output via
# ``predict_noise=True`` (``use_noise_mask=True`` for the DCCRN separator).
# NOTE: keep this construction order. The separator constructors presumably draw
# their initial weights from the global torch RNG.
stft_decoder = STFTDecoder(
    n_fft=32,
    hop_length=16,
)

conv_encoder = ConvEncoder(
    channel=17,
    kernel_size=36,
    stride=18,
)

conv_decoder = ConvDecoder(
    channel=17,
    kernel_size=36,
    stride=18,
)

null_encoder = NullEncoder()

null_decoder = NullDecoder()

conformer_separator = ConformerSeparator(
    input_dim=17,
    predict_noise=True,
    adim=8,
    aheads=2,
    layers=2,
    linear_units=10,
)

dc_crn_separator = DC_CRNSeparator(
    input_dim=17,
    predict_noise=True,
    input_channels=[2, 2, 4],
)

dccrn_separator = DCCRNSeparator(
    input_dim=17,
    num_spk=1,
    kernel_num=[32, 64, 128],
    use_noise_mask=True,
)

dprnn_separator = DPRNNSeparator(
    input_dim=17,
    predict_noise=True,
    layer=1,
    unit=10,
    segment_size=4,
)

# NOTE(review): input_dim is 16 here while the other separators use 17;
# presumably intentional for DPTNet, but confirm.
dptnet_separator = DPTNetSeparator(
    input_dim=16,
    predict_noise=True,
    layer=1,
    unit=10,
    segment_size=4,
)

rnn_separator = RNNSeparator(
    input_dim=17,
    predict_noise=True,
    layer=1,
    unit=10,
)