def test_tcn_separator_invalid_type():
    """Check that TCNSeparator rejects unsupported option strings.

    Constructing the separator with an unknown ``nonlinear`` value or an
    unknown ``norm_type`` value is expected to raise ``ValueError``.
    """
    # Each entry is one invalid keyword combination, tried in isolation.
    invalid_options = (
        {"nonlinear": "fff"},
        {"norm_type": "xxx"},
    )
    for bad_kwargs in invalid_options:
        with pytest.raises(ValueError):
            TCNSeparator(input_dim=10, **bad_kwargs)
def test_tcn_separator_forward_backward_real(
    input_dim,
    layer,
    num_spk,
    nonlinear,
    stack,
    bottleneck_dim,
    hidden_dim,
    kernel,
    causal,
    norm_type,
):
    """Run one forward and backward pass of TCNSeparator on a real tensor.

    All arguments are forwarded unchanged to the ``TCNSeparator``
    constructor (presumably supplied by pytest parametrization that is
    not visible in this block).

    The test feeds a random batch of shape ``(2, 10, input_dim)`` with
    lengths ``[10, 8]``, checks that the first returned element is a
    ``Tensor`` and that ``num_spk`` elements are returned, then
    back-propagates through the mean absolute value of the first element.
    """
    separator = TCNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        layer=layer,
        stack=stack,
        bottleneck_dim=bottleneck_dim,
        hidden_dim=hidden_dim,
        kernel=kernel,
        causal=causal,
        norm_type=norm_type,
        nonlinear=nonlinear,
    )

    # The input is created after the model so the RNG is consumed in the
    # same order on every run.
    features = torch.rand(2, 10, input_dim)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    # The forward call must return exactly three values; only the first is
    # inspected here.
    masked, _, _ = separator(features, ilens=lengths)

    first_output = masked[0]
    assert isinstance(first_output, Tensor)
    assert len(masked) == num_spk

    # A scalar loss is needed to exercise the backward pass.
    loss = first_output.abs().mean()
    loss.backward()
def test_tcn_separator_output():
    """Check the structure of TCNSeparator outputs for 1 and 2 speakers.

    For each speaker count the separator is built with that ``num_spk``,
    switched to eval mode, and run on a fixed random batch of shape
    ``(2, 10, 10)`` with lengths ``[10, 8]``. The test verifies that:

    * the first return value is a list with exactly ``num_spk`` entries,
    * the third return value is a dict,
    * the dict holds a ``mask_spk<i>`` entry (1-based) for every speaker,
      whose shape matches the corresponding list entry.
    """
    x = torch.rand(2, 10, 10)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in range(1, 3):
        model = TCNSeparator(
            input_dim=10,
            # The loop variable is the speaker count; it was previously
            # passed as ``layer``, so the speaker count never varied.
            num_spk=num_spk,
            layer=2,
            stack=2,
            bottleneck_dim=3,
            hidden_dim=3,
            kernel=3,
            causal=False,
        )
        model.eval()

        specs, _, others = model(x, x_lens)

        assert isinstance(specs, list)
        assert isinstance(others, dict)
        # Without this check, extra or missing outputs would go unnoticed.
        assert len(specs) == num_spk

        for spk_idx, spec in enumerate(specs, start=1):
            mask_key = "mask_spk{}".format(spk_idx)
            assert mask_key in others
            assert spec.shape == others[mask_key].shape
kernel_size=32, stride=16, ) rnn_separator = RNNSeparator( input_dim=15, layer=1, unit=10, ) dprnn_separator = DPRNNSeparator(input_dim=15, layer=1, unit=10, segment_size=4) tcn_separator = TCNSeparator( input_dim=15, layer=2, stack=1, bottleneck_dim=10, hidden_dim=10, kernel=3, ) transformer_separator = TransformerSeparator( input_dim=15, adim=8, aheads=2, layers=2, linear_units=10, ) @pytest.mark.parametrize( "encoder, decoder",
input_dim=17, enc_dim=4, kernel_size=4, hidden_size=4, num_spk=2, num_layers=2, segment_size=4, bidirectional=False, input_normalize=False, ) tcn_separator = TCNSeparator( input_dim=17, predict_noise=True, layer=2, stack=1, bottleneck_dim=10, hidden_dim=10, kernel=3, ) transformer_separator = TransformerSeparator( input_dim=17, predict_noise=True, adim=8, aheads=2, layers=2, linear_units=10 ) si_snr_loss = SISNRLoss() tf_mse_loss = FrequencyDomainMSE() tf_l1_loss = FrequencyDomainL1() pit_wrapper = PITSolver(criterion=si_snr_loss) multilayer_pit_solver = MultiLayerPITSolver(criterion=si_snr_loss)