def test_rnn_separator_forward_backward_complex(
    input_dim, rnn_type, layer, unit, dropout, num_spk, nonlinear
):
    """Forward/backward smoke test of RNNSeparator on complex input.

    A batch of two random complex spectra (valid lengths 10 and 8) is fed
    through the separator in training mode.  The test requires exactly
    ``num_spk`` outputs, each of them a ``ComplexTensor``.  The scalar used
    for ``backward()`` sums over all speakers, so gradient flow is exercised
    through every speaker's output rather than only the first one.
    """
    model = RNNSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        nonlinear=nonlinear,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, _flens, _others = model(x, ilens=x_lens)

    assert len(masked) == num_spk
    # Check every speaker's output, not just the first.
    for spk_out in masked:
        assert isinstance(spk_out, ComplexTensor)

    # Summing over speakers makes backward() reach every speaker branch.
    loss = sum(spk_out.abs().mean() for spk_out in masked)
    loss.backward()
def test_rnn_separator_invalid_type():
    """RNNSeparator must reject an unknown ``nonlinear`` name with ValueError.

    All other arguments are valid; only ``nonlinear="fff"`` is bogus.
    """
    bad_config = dict(
        input_dim=10,
        rnn_type="rnn",
        layer=2,
        unit=10,
        dropout=0.1,
        num_spk=2,
        nonlinear="fff",
    )
    with pytest.raises(ValueError):
        RNNSeparator(**bad_config)
def test_rnn_separator_output():
    """Check RNNSeparator's output containers in eval mode.

    For one and two speakers, the separator must return:
      * a list holding exactly one estimate per speaker, and
      * a dict with a ``mask_spk<n>`` entry (1-based) for every speaker,
        whose shape matches the corresponding estimate.
    """
    x = torch.rand(2, 10, 10)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = RNNSeparator(
            input_dim=10,
            rnn_type="rnn",
            layer=2,
            unit=10,
            dropout=0.1,
            num_spk=num_spk,
            nonlinear="relu",
        )
        model.eval()
        specs, _, others = model(x, x_lens)

        assert isinstance(specs, list)
        assert isinstance(others, dict)
        # Exactly one estimate per speaker.
        assert len(specs) == num_spk
        for n in range(num_spk):
            key = "mask_spk{}".format(n + 1)
            assert key in others
            assert specs[n].shape == others[key].shape
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.layers.label_aggregation import LabelAggregate

# Module-level components, constructed once at import time.

# STFT encoder/decoder pair sharing n_fft=32 and hop_length=16.
# NOTE(review): the matching settings presumably let the decoder invert the
# encoder's output; confirm.
enh_stft_encoder = STFTEncoder(
    n_fft=32,
    hop_length=16,
)
enh_stft_decoder = STFTDecoder(
    n_fft=32,
    hop_length=16,
)

# Small single-speaker RNN separator (1 layer, 10 units).
# NOTE(review): input_dim=17 presumably equals the one-sided STFT size
# n_fft // 2 + 1 for n_fft=32; confirm against STFTEncoder's output dim.
enh_rnn_separator = RNNSeparator(
    input_dim=17,
    layer=1,
    unit=10,
    num_spk=1,
)

# SI-SNR loss with its default settings, wrapped by FixedOrderSolver.
# NOTE(review): judging by the name, the solver presumably matches
# references and estimates in fixed order, with no permutation search;
# confirm.
si_snr_loss = SISNRLoss()
fix_order_solver = FixedOrderSolver(criterion=si_snr_loss)

# Tiny feature front-end (fs=300, 32-point FFT/window, hop 24, 32 mel bins).
# Presumably chosen this small to keep tests fast.
default_frontend = DefaultFrontend(
    fs=300,
    n_fft=32,
    win_length=32,
    hop_length=24,
    n_mels=32,
)
conv_encoder = ConvEncoder( channel=15, kernel_size=32, stride=16, ) conv_decoder = ConvDecoder( channel=15, kernel_size=32, stride=16, ) rnn_separator = RNNSeparator( input_dim=15, layer=1, unit=10, ) dprnn_separator = DPRNNSeparator(input_dim=15, layer=1, unit=10, segment_size=4) tcn_separator = TCNSeparator( input_dim=15, layer=2, stack=1, bottleneck_dim=10, hidden_dim=10, kernel=3, ) transformer_separator = TransformerSeparator(
input_dim=17, predict_noise=True, input_channels=[2, 2, 4] ) dccrn_separator = DCCRNSeparator( input_dim=17, num_spk=1, kernel_num=[32, 64, 128], use_noise_mask=True ) dprnn_separator = DPRNNSeparator( input_dim=17, predict_noise=True, layer=1, unit=10, segment_size=4 ) dptnet_separator = DPTNetSeparator( input_dim=16, predict_noise=True, layer=1, unit=10, segment_size=4 ) rnn_separator = RNNSeparator(input_dim=17, predict_noise=True, layer=1, unit=10) svoice_separator = SVoiceSeparator( input_dim=17, enc_dim=4, kernel_size=4, hidden_size=4, num_spk=2, num_layers=2, segment_size=4, bidirectional=False, input_normalize=False, ) tcn_separator = TCNSeparator( input_dim=17,