Пример #1
0
def test_dptnet(training, loss_wrappers):
    """Run one forward pass of an enhancement model built around DPTNet.

    A 16-channel convolutional encoder/decoder pair is wrapped around the
    module-level ``dptnet_separator`` (defined elsewhere in this file). The
    model is switched to train or eval mode, then called on random data
    with two reference signals.

    Args:
        training: Truthy -> ``enh_model.train()``; falsy -> ``enh_model.eval()``.
        loss_wrappers: Forwarded unchanged to ``ESPnetEnhancementModel``.
    """
    # Encoder and decoder share one convolution configuration.
    conv_cfg = dict(channel=16, kernel_size=36, stride=18)
    encoder = ConvEncoder(**conv_cfg)
    decoder = ConvDecoder(**conv_cfg)

    # Two random rows of 300 samples with declared lengths [300, 200], plus
    # two random reference signals of the same size.
    mixture = torch.randn(2, 300)
    mixture_lengths = torch.LongTensor([300, 200])
    references = [torch.randn(2, 300).float() for _ in range(2)]

    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=dptnet_separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    set_mode = enh_model.train if training else enh_model.eval
    set_mode()

    # Keyword arguments: the mixture, its lengths, then speech_ref1, speech_ref2.
    batch = {"speech_mix": mixture, "speech_mix_lengths": mixture_lengths}
    for ref_idx, ref in enumerate(references, start=1):
        batch["speech_ref{}".format(ref_idx)] = ref

    # Unpacking also checks that the model returns exactly three values.
    loss, stats, weight = enh_model(**batch)
Пример #2
0

# Label aggregation over 32-sample windows with a 16-sample hop. These are
# the same numbers as the encoder/decoder kernel size and stride below.
label_aggregator = LabelAggregate(win_length=32, hop_length=16)

# 17-channel convolutional encoder and its matching decoder.
enh_encoder = ConvEncoder(channel=17, kernel_size=32, stride=16)
enh_decoder = ConvDecoder(channel=17, kernel_size=32, stride=16)

# Small TCN separator (the ``Nomask`` variant). Its input width is taken
# from the encoder rather than hard-coded.
tcn_separator = TCNSeparatorNomask(
    input_dim=enh_encoder.output_dim,
    layer=2,
    stack=1,
    # NOTE(review): bottleneck_dim presumably has to agree with the
    # bottleneck_dim=10 given to MultiMask further down — confirm.
    bottleneck_dim=10,
    hidden_dim=10,
    kernel=3,
)

mask_module = MultiMask(
    bottleneck_dim=10,
    max_num_spk=3,
Пример #3
0
)

# STFT decoder with a 16-point FFT and an 8-sample hop.
stft_decoder = STFTDecoder(n_fft=16, hop_length=8)

# 9-channel convolutional encoder/decoder pair with identical kernel/stride.
conv_encoder = ConvEncoder(channel=9, kernel_size=20, stride=10)
conv_decoder = ConvDecoder(channel=9, kernel_size=20, stride=10)

# Single-layer recurrent separators. input_dim=9 equals the conv channel
# count above (NOTE(review): assumes they are paired with conv_encoder).
rnn_separator = RNNSeparator(input_dim=9, layer=1, unit=10)

dprnn_separator = DPRNNSeparator(
    input_dim=9,
    layer=1,
    unit=10,
    segment_size=4,
)

tcn_separator = TCNSeparator(
    input_dim=9,
    layer=2,
    stack=1,
Пример #4
0
from espnet2.enh.separator.tcn_separator import TCNSeparator
from espnet2.enh.separator.transformer_separator import TransformerSeparator

# True when the installed torch is at least 1.9.0. ``V`` is a version parser
# imported elsewhere in this file (presumably ``packaging.version.parse``).
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

# STFT encoder with a 32-point FFT and a 16-sample hop.
stft_encoder = STFTEncoder(n_fft=32, hop_length=16)

# Same STFT encoder configuration, but with use_builtin_complex=True.
stft_encoder_builtin_complex = STFTEncoder(
    n_fft=32,
    hop_length=16,
    use_builtin_complex=True,
)
# Backward-compatible alias for the original, misspelled name ("bultin").
# Both names refer to the same object, so existing references keep working.
stft_encoder_bultin_complex = stft_encoder_builtin_complex

# STFT decoder matching the encoder's FFT size and hop.
stft_decoder = STFTDecoder(n_fft=32, hop_length=16)

# 17-channel convolutional encoder/decoder pair with identical kernel/stride.
conv_encoder = ConvEncoder(channel=17, kernel_size=36, stride=18)
conv_decoder = ConvDecoder(channel=17, kernel_size=36, stride=18)

# Null encoder/decoder, constructed with no arguments.
null_encoder = NullEncoder()
null_decoder = NullDecoder()

# Separators configured with input_dim=17, the same number as the conv
# channel count above.
dc_crn_separator = DC_CRNSeparator(input_dim=17, input_channels=[2, 2, 4])

dccrn_separator = DCCRNSeparator(
    input_dim=17,
    num_spk=1,
    kernel_num=[32, 64, 128],
)

dprnn_separator = DPRNNSeparator(
    input_dim=17,
    layer=1,
    unit=10,
    segment_size=4,
)