コード例 #1
0
def test_rnn_separator_forward_backward_complex(input_dim, rnn_type, layer,
                                                unit, dropout, num_spk,
                                                nonlinear):
    """RNNSeparator must accept ComplexTensor input and support backprop.

    A separator is built in training mode and fed a random complex batch of
    two sequences (valid lengths 10 and 8). The test then checks that:

    * the first separated output is a ``ComplexTensor``;
    * the number of separated outputs equals ``num_spk``;
    * a scalar loss derived from the first output can be back-propagated.
    """
    separator = RNNSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        nonlinear=nonlinear,
    )
    separator.train()

    batch_shape = (2, 10, input_dim)
    # Arguments are evaluated left to right: real part first, then imaginary.
    spec = ComplexTensor(torch.rand(*batch_shape), torch.rand(*batch_shape))
    spec_lens = torch.tensor([10, 8], dtype=torch.long)

    # The separator returns a 3-tuple; only the separated outputs are checked.
    outputs, _, _ = separator(spec, ilens=spec_lens)
    first_output = outputs[0]

    assert isinstance(first_output, ComplexTensor)
    assert len(outputs) == num_spk

    first_output.abs().mean().backward()
コード例 #2
0
def test_rnn_separator_invalid_type():
    """Constructing RNNSeparator with ``nonlinear="fff"`` must raise ValueError."""
    bad_config = dict(
        input_dim=10,
        rnn_type="rnn",
        layer=2,
        unit=10,
        dropout=0.1,
        num_spk=2,
        nonlinear="fff",  # the deliberately unsupported activation name
    )
    with pytest.raises(ValueError):
        RNNSeparator(**bad_config)
コード例 #3
0
def test_rnn_separator_output():
    """Check the output structure of RNNSeparator for one and two speakers.

    For each speaker count, the separator is built in eval mode and run on
    a random real-valued batch. It must return:

    * a list of separated specs;
    * a dict of extras holding a ``mask_spk<k>`` entry for every speaker,
      whose shape matches that speaker's spec.
    """
    feats = torch.rand(2, 10, 10)
    feat_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in (1, 2):
        separator = RNNSeparator(
            input_dim=10,
            rnn_type="rnn",
            layer=2,
            unit=10,
            dropout=0.1,
            num_spk=num_spk,
            nonlinear="relu",
        )
        separator.eval()
        specs, _, others = separator(feats, feat_lens)

        assert isinstance(specs, list)
        assert isinstance(others, dict)
        # Mask keys are 1-based ("mask_spk1", ...) while specs is 0-indexed.
        for spk in range(1, num_spk + 1):
            mask_key = "mask_spk{}".format(spk)
            assert mask_key in others
            assert specs[spk - 1].shape == others[mask_key].shape
コード例 #4
0
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.layers.label_aggregation import LabelAggregate

# Small module instances for tests. The configurations are tiny (e.g. layer=1,
# unit=10), presumably to keep tests fast -- confirm.
# NOTE(review): these objects are constructed in a fixed order. If any of them
# initialise parameters from torch's global RNG, reordering them would change
# their initial weights, so keep the order unless that is confirmed harmless.

# STFT encoder/decoder pair sharing n_fft=32 and hop_length=16 (presumably so
# the decoder inverts the encoder's output -- confirm).
enh_stft_encoder = STFTEncoder(
    n_fft=32,
    hop_length=16,
)

enh_stft_decoder = STFTDecoder(
    n_fft=32,
    hop_length=16,
)

# Single-speaker RNN separator. input_dim=17 presumably equals the number of
# one-sided STFT bins from the encoder above (n_fft // 2 + 1 = 17) -- confirm.
enh_rnn_separator = RNNSeparator(
    input_dim=17,
    layer=1,
    unit=10,
    num_spk=1,
)

si_snr_loss = SISNRLoss()

# FixedOrderSolver configured with si_snr_loss as its criterion.
fix_order_solver = FixedOrderSolver(criterion=si_snr_loss)

# Feature frontend. Its hop_length=24 differs from the hop_length=16 used by
# the enhancement STFT encoder/decoder above.
default_frontend = DefaultFrontend(
    fs=300,
    n_fft=32,
    win_length=32,
    hop_length=24,
    n_mels=32,
)
コード例 #5
0
# ConvEncoder/ConvDecoder pair configured identically
# (channel=15, kernel_size=32, stride=16).
# NOTE(review): construction order is kept as-is in case any module draws its
# initial weights from torch's global RNG.
conv_encoder = ConvEncoder(
    channel=15,
    kernel_size=32,
    stride=16,
)

conv_decoder = ConvDecoder(
    channel=15,
    kernel_size=32,
    stride=16,
)

# The separators below all use input_dim=15, the same value as the conv
# encoder's channel=15 (presumably the encoder's feature size -- confirm).
rnn_separator = RNNSeparator(
    input_dim=15,
    layer=1,
    unit=10,
)

dprnn_separator = DPRNNSeparator(input_dim=15, layer=1, unit=10, segment_size=4)

tcn_separator = TCNSeparator(
    input_dim=15,
    layer=2,
    stack=1,
    bottleneck_dim=10,
    hidden_dim=10,
    kernel=3,
)

transformer_separator = TransformerSeparator(
コード例 #6
0
    input_dim=17, predict_noise=True, input_channels=[2, 2, 4]
)

# Separator instances configured to also produce a noise estimate, via
# use_noise_mask=True (DCCRN) or predict_noise=True (DPRNN, DPTNet, RNN).
# NOTE(review): construction order is kept as-is in case any module draws its
# initial weights from torch's global RNG.
dccrn_separator = DCCRNSeparator(
    input_dim=17, num_spk=1, kernel_num=[32, 64, 128], use_noise_mask=True
)

dprnn_separator = DPRNNSeparator(
    input_dim=17, predict_noise=True, layer=1, unit=10, segment_size=4
)

# NOTE(review): input_dim=16 here, while the neighbouring separators use 17 --
# confirm this difference is intentional.
dptnet_separator = DPTNetSeparator(
    input_dim=16, predict_noise=True, layer=1, unit=10, segment_size=4
)

rnn_separator = RNNSeparator(input_dim=17, predict_noise=True, layer=1, unit=10)

# Two-speaker SVoice separator with a small, unidirectional configuration and
# input normalisation disabled.
svoice_separator = SVoiceSeparator(
    input_dim=17,
    enc_dim=4,
    kernel_size=4,
    hidden_size=4,
    num_spk=2,
    num_layers=2,
    segment_size=4,
    bidirectional=False,
    input_normalize=False,
)

tcn_separator = TCNSeparator(
    input_dim=17,