def test_dan_separator_forward_backward_complex(
    input_dim, rnn_type, layer, unit, dropout, num_spk, emb_D, nonlinear
):
    """Forward/backward smoke test for DANSeparator on a ComplexTensor input.

    Reference features are passed to the model through ``additional`` under
    the ``"feature_ref"`` key: one ``ComplexTensor`` per speaker, each built
    from the same real/imag parts as the mixture.
    """
    model = DANSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        emb_D=emb_D,
        nonlinear=nonlinear,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    # One reference spectrum per speaker; the loop index itself is unused.
    sep_others = {
        "feature_ref": [ComplexTensor(real, imag) for _ in range(num_spk)]
    }

    masked, flens, others = model(x, ilens=x_lens, additional=sep_others)

    assert isinstance(masked[0], ComplexTensor)
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()
def test_dpcl_separator_forward_backward_complex(
    input_dim, rnn_type, layer, unit, dropout, num_spk, emb_D, nonlinear
):
    """Check that DPCLSeparator exposes a differentiable ``tf_embedding``.

    Only ``others["tf_embedding"]`` is inspected and back-propagated; the
    returned masks and lengths are ignored.
    """
    separator = DPCLSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        emb_D=emb_D,
        nonlinear=nonlinear,
    )
    separator.train()

    # Real part is drawn before the imaginary part (arguments evaluate left
    # to right).
    spec = ComplexTensor(
        torch.rand(2, 10, input_dim), torch.rand(2, 10, input_dim)
    )
    lengths = torch.tensor([10, 8], dtype=torch.long)

    _, _, extra = separator(spec, ilens=lengths)

    assert "tf_embedding" in extra
    embedding = extra["tf_embedding"]
    embedding.abs().mean().backward()
def test_rnn_separator_forward_backward_complex(
    input_dim, rnn_type, layer, unit, dropout, num_spk, nonlinear
):
    """Forward/backward smoke test for RNNSeparator on a ComplexTensor input."""
    separator = RNNSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        nonlinear=nonlinear,
    )
    separator.train()

    spec = ComplexTensor(
        torch.rand(2, 10, input_dim), torch.rand(2, 10, input_dim)
    )
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    first = estimates[0]
    assert isinstance(first, ComplexTensor)
    assert len(estimates) == num_spk

    first.abs().mean().backward()
def test_tf_dpcl_loss_criterion_forward(loss_type):
    """FrequencyDomainDPCL must return one loss value per batch element."""
    criterion = FrequencyDomainDPCL(loss_type=loss_type)

    batch = 2
    # Random input of shape (batch, 10 * 200, 40); 10 x 200 matches the
    # spatial size of each reference spectrum below.
    inf = torch.rand(batch, 10 * 200, 40)

    def _random_spec():
        return ComplexTensor(
            torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)
        )

    # Three reference spectra, reduced to magnitudes.
    ref = []
    for _ in range(3):
        ref.append(abs(_random_spec()))

    loss = criterion(ref, inf)
    assert loss.shape == (batch,), "Invalid loss shape with " + criterion.name
def test_STFTDecoder_backward(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Gradients must flow through STFTDecoder back to the input spectrum.

    NOTE(review): a function with this same name is defined again further
    down in this source (with an ``is_torch_1_2_plus`` skip guard). If both
    live in the same module, only the later definition is collected by
    pytest and this one never runs.
    """
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    # A one-sided spectrum has n_fft // 2 + 1 bins; otherwise n_fft bins.
    n_freq = n_fft // 2 + 1 if onesided else n_fft
    real = torch.rand(2, 300, n_freq, requires_grad=True)
    imag = torch.rand(2, 300, n_freq, requires_grad=True)
    spec = ComplexTensor(real, imag)
    spec_lens = torch.tensor(
        [300 * hop_length, 295 * hop_length], dtype=torch.long
    )

    wav, _ = decoder(spec, spec_lens)
    wav.sum().backward()
def test_tf_domain_criterion_forward(criterion_class, mask_type, compute_on_mask):
    """A TF-domain criterion must return one loss value per batch element.

    With ``compute_on_mask`` the target is the mask label built by the
    criterion from the mixture and references; otherwise it is the reference
    magnitude.
    """
    criterion = criterion_class(
        compute_on_mask=compute_on_mask, mask_type=mask_type
    )

    batch = 2
    estimates = [torch.rand(batch, 10, 200)]
    target_specs = [
        ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200))
    ]
    mixture = ComplexTensor(
        torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)
    )

    targets = (
        criterion.create_mask_label(mixture, target_specs)
        if compute_on_mask
        else [abs(spec) for spec in target_specs]
    )

    loss = criterion(targets[0], estimates[0])
    assert loss.shape == (batch,)
def get_power_spectral_density_matrix(
    complex_tensor: ComplexTensor,
) -> ComplexTensor:
    """Per-frame cross-channel power spectral density (PSD) matrices.

    For every time frame, computes the outer product of the channel vector
    with its complex conjugate. No reduction over time is performed, so the
    T axis is kept in the output. To obtain a single [..., F, C, C] matrix
    per frequency, callers must reduce over the T axis (``dim=-3``)
    themselves.

    Args:
        complex_tensor: [..., F, C, T]

    Returns:
        psd: [..., F, T, C, C]
    """
    # outer product: [..., C_1, T] x [..., C_2, T] => [..., T, C_1, C_2]
    return FC.einsum("...ct,...et->...tce", [complex_tensor, complex_tensor.conj()])
def apply_crf_filter(cRM_filter: ComplexTensor, mix: ComplexTensor) -> ComplexTensor:
    """Apply a complex ratio filter to a multi-channel mixture.

    The conjugated filter is multiplied with the mixture and summed over the
    Filter_delay axis. The filter has no channel index, so the same filter is
    applied to every channel.

    Args:
        cRM_filter: complex ratio filter, [B, F, T, Filter_delay]
        mix: mixture, [B, C, F, Filter_delay, T]

    Returns:
        [B, C, F, T]
    """
    filter_conj = cRM_filter.conj()
    # Contract over d (Filter_delay); b, f and t are matched, c is carried through.
    return FC.einsum("bftd, bcfdt -> bcft", [filter_conj, mix])
def test_dccrn_separator_forward_backward_complex(
    input_dim,
    num_spk,
    rnn_layer,
    rnn_units,
    masking_mode,
    use_clstm,
    bidirectional,
    use_cbn,
    kernel_size,
    use_builtin_complex,
    use_noise_mask,
):
    """Forward/backward smoke test for DCCRNSeparator on a ComplexTensor input.

    The expected output type depends on ``use_builtin_complex`` and
    ``is_torch_1_9_plus``: a plain ``torch.Tensor`` when both are true,
    otherwise a ``ComplexTensor``.
    """
    config = dict(
        input_dim=input_dim,
        num_spk=num_spk,
        rnn_layer=rnn_layer,
        rnn_units=rnn_units,
        masking_mode=masking_mode,
        use_clstm=use_clstm,
        bidirectional=bidirectional,
        use_cbn=use_cbn,
        kernel_size=kernel_size,
        kernel_num=[32, 64, 128],
        use_builtin_complex=use_builtin_complex,
        use_noise_mask=use_noise_mask,
    )
    separator = DCCRNSeparator(**config)
    separator.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    spec = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    expected_type = (
        torch.Tensor
        if use_builtin_complex and is_torch_1_9_plus
        else ComplexTensor
    )
    assert isinstance(estimates[0], expected_type)
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
def test_dc_crn_separator_forward_backward_complex(
    input_dim,
    num_spk,
    input_channels,
    enc_hid_channels,
    enc_layers,
    glstm_groups,
    glstm_layers,
    glstm_bidirectional,
    glstm_rearrange,
    mode,
):
    """Forward/backward smoke test for DC_CRNSeparator on a complex input.

    The input is a native complex tensor when ``is_torch_1_9_plus`` is true,
    otherwise a ``ComplexTensor``.
    """
    # Kernel/stride/padding settings are fixed; everything else comes from
    # the test parameters.
    separator = DC_CRNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=enc_hid_channels,
        enc_kernel_size=(1, 3),
        enc_padding=(0, 1),
        enc_last_kernel_size=(1, 3),
        enc_last_stride=(1, 2),
        enc_last_padding=(0, 1),
        enc_layers=enc_layers,
        skip_last_kernel_size=(1, 3),
        skip_last_stride=(1, 1),
        skip_last_padding=(0, 1),
        glstm_groups=glstm_groups,
        glstm_layers=glstm_layers,
        glstm_bidirectional=glstm_bidirectional,
        glstm_rearrange=glstm_rearrange,
        mode=mode,
    )
    separator.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    if is_torch_1_9_plus:
        spec = torch.complex(real, imag)
    else:
        spec = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    assert is_complex(estimates[0])
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
def test_transformer_separator_forward_backward_complex(
    input_dim,
    adim,
    layers,
    aheads,
    linear_units,
    num_spk,
    nonlinear,
    positionwise_layer_type,
    normalize_before,
    concat_after,
    dropout_rate,
    positional_dropout_rate,
    attention_dropout_rate,
    use_scaled_pos_enc,
):
    """Forward/backward smoke test for TransformerSeparator on a ComplexTensor."""
    config = dict(
        input_dim=input_dim,
        num_spk=num_spk,
        adim=adim,
        aheads=aheads,
        layers=layers,
        linear_units=linear_units,
        positionwise_layer_type=positionwise_layer_type,
        normalize_before=normalize_before,
        concat_after=concat_after,
        dropout_rate=dropout_rate,
        positional_dropout_rate=positional_dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        use_scaled_pos_enc=use_scaled_pos_enc,
        nonlinear=nonlinear,
    )
    separator = TransformerSeparator(**config)
    separator.train()

    spec = ComplexTensor(
        torch.rand(2, 10, input_dim), torch.rand(2, 10, input_dim)
    )
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    first = estimates[0]
    assert isinstance(first, ComplexTensor)
    assert len(estimates) == num_spk

    first.abs().mean().backward()
def test_dptnet_separator_forward_backward_complex(
    input_dim,
    post_enc_relu,
    rnn_type,
    bidirectional,
    num_spk,
    unit,
    att_heads,
    dropout,
    activation,
    norm_type,
    layer,
    segment_size,
    nonlinear,
):
    """Forward/backward smoke test for DPTNetSeparator on a ComplexTensor."""
    separator = DPTNetSeparator(
        input_dim=input_dim,
        post_enc_relu=post_enc_relu,
        rnn_type=rnn_type,
        bidirectional=bidirectional,
        num_spk=num_spk,
        unit=unit,
        att_heads=att_heads,
        dropout=dropout,
        activation=activation,
        norm_type=norm_type,
        layer=layer,
        segment_size=segment_size,
        nonlinear=nonlinear,
    )
    separator.train()

    re_part = torch.rand(2, 10, input_dim)
    im_part = torch.rand(2, 10, input_dim)
    spec = ComplexTensor(re_part, im_part)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    assert isinstance(estimates[0], ComplexTensor)
    assert len(estimates) == num_spk

    loss = estimates[0].abs().mean()
    loss.backward()
def test_dc_crn_separator_output():
    """In eval mode, DC_CRNSeparator returns a list of specs and a dict of masks.

    For each speaker n (1-based), ``others`` must contain ``"mask_spk<n>"``
    with the same shape as the corresponding separated spec.
    """
    real = torch.rand(2, 10, 17)
    imag = torch.rand(2, 10, 17)
    if is_torch_1_9_plus:
        mixture = torch.complex(real, imag)
    else:
        mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    for n_spk in (1, 2):
        separator = DC_CRNSeparator(
            input_dim=17, num_spk=n_spk, input_channels=[2, 2, 4]
        )
        separator.eval()

        specs, _, others = separator(mixture, lengths)
        assert isinstance(specs, list)
        assert isinstance(others, dict)

        for spk in range(1, n_spk + 1):
            key = "mask_spk{}".format(spk)
            assert key in others
            assert specs[spk - 1].shape == others[key].shape
def test_dc_crn_separator_multich_input(
    num_spk,
    input_channels,
    enc_kernel_size,
    enc_padding,
    enc_last_kernel_size,
    enc_last_stride,
    enc_last_padding,
    skip_last_kernel_size,
    skip_last_stride,
    skip_last_padding,
):
    """Forward/backward smoke test for DC_CRNSeparator with multi-channel input.

    The input has ``input_channels[0] // 2`` channels and 33 frequency bins.
    """
    separator = DC_CRNSeparator(
        input_dim=33,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=2,
        enc_kernel_size=enc_kernel_size,
        enc_padding=enc_padding,
        enc_last_kernel_size=enc_last_kernel_size,
        enc_last_stride=enc_last_stride,
        enc_last_padding=enc_last_padding,
        enc_layers=3,
        skip_last_kernel_size=skip_last_kernel_size,
        skip_last_stride=skip_last_stride,
        skip_last_padding=skip_last_padding,
    )
    separator.train()

    n_channels = input_channels[0] // 2
    real = torch.rand(2, 10, n_channels, 33)
    imag = torch.rand(2, 10, n_channels, 33)
    if is_torch_1_9_plus:
        mixture = torch.complex(real, imag)
    else:
        mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    assert is_complex(estimates[0])
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
def test_STFTDecoder_backward(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Gradients must flow through STFTDecoder back to the input spectrum.

    Skipped when ``is_torch_1_2_plus`` is false.
    """
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    # A one-sided spectrum has n_fft // 2 + 1 bins; otherwise n_fft bins.
    n_freq = n_fft // 2 + 1 if onesided else n_fft
    real = torch.rand(2, 300, n_freq, requires_grad=True)
    imag = torch.rand(2, 300, n_freq, requires_grad=True)
    spec = ComplexTensor(real, imag)
    spec_lens = torch.tensor(
        [300 * hop_length, 295 * hop_length], dtype=torch.long
    )

    wav, _ = decoder(spec, spec_lens)
    wav.sum().backward()
def test_dpcl_e2e_separator_forward_backward_complex(
    input_dim,
    rnn_type,
    layer,
    unit,
    dropout,
    num_spk,
    predict_noise,
    emb_D,
    nonlinear,
    alpha,
    max_iteration,
):
    """Forward/backward smoke test for DPCLE2ESeparator on a ComplexTensor."""
    config = dict(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        predict_noise=predict_noise,
        emb_D=emb_D,
        nonlinear=nonlinear,
        alpha=alpha,
        max_iteration=max_iteration,
    )
    separator = DPCLE2ESeparator(**config)
    separator.train()

    spec = ComplexTensor(
        torch.rand(2, 10, input_dim), torch.rand(2, 10, input_dim)
    )
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    assert isinstance(estimates[0], ComplexTensor)
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
def test_rnn_separator_output():
    """In eval mode, the separator returns a list of specs and a dict of masks.

    NOTE(review): despite the test name, this exercises ``DCCRNSeparator``,
    not an RNN separator. Consider renaming it; note that a rename changes
    the collected pytest id.
    """
    real = torch.rand(2, 10, 9)
    imag = torch.rand(2, 10, 9)
    mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    for n_spk in (1, 2):
        separator = DCCRNSeparator(
            input_dim=9, num_spk=n_spk, kernel_num=[32, 64, 128]
        )
        separator.eval()

        specs, _, others = separator(mixture, lengths)
        assert isinstance(specs, list)
        assert isinstance(others, dict)

        for spk in range(1, n_spk + 1):
            key = "mask_spk{}".format(spk)
            assert key in others
            assert specs[spk - 1].shape == others[key].shape
def test_skim_separator_forward_backward_complex(
    input_dim,
    layer,
    causal,
    unit,
    dropout,
    num_spk,
    nonlinear,
    mem_type,
    segment_size,
    seg_overlap,
):
    """Forward/backward smoke test for SkiMSeparator on a ComplexTensor input."""
    separator = SkiMSeparator(
        input_dim=input_dim,
        causal=causal,
        num_spk=num_spk,
        nonlinear=nonlinear,
        layer=layer,
        unit=unit,
        segment_size=segment_size,
        dropout=dropout,
        mem_type=mem_type,
        seg_overlap=seg_overlap,
    )
    separator.train()

    spec = ComplexTensor(
        torch.rand(2, 10, input_dim), torch.rand(2, 10, input_dim)
    )
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(spec, ilens=lengths)

    first = estimates[0]
    assert isinstance(first, ComplexTensor)
    assert len(estimates) == num_spk

    first.abs().mean().backward()
def test_tcn_separator_forward_backward_complex(
    input_dim,
    layer,
    num_spk,
    nonlinear,
    stack,
    bottleneck_dim,
    hidden_dim,
    kernel,
    causal,
    norm_type,
):
    """Forward/backward smoke test for TCNSeparator on a ComplexTensor input.

    Checks that one ComplexTensor is returned per speaker and that gradients
    can be propagated from the first output.
    """
    model = TCNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        layer=layer,
        stack=stack,
        bottleneck_dim=bottleneck_dim,
        hidden_dim=hidden_dim,
        kernel=kernel,
        causal=causal,
        norm_type=norm_type,
        nonlinear=nonlinear,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    # Local was previously misspelled "maksed".
    masked, flens, others = model(x, ilens=x_lens)

    assert isinstance(masked[0], ComplexTensor)
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()
def apply_beamforming_vector(
    beamforming_vector: ComplexTensor, mix: ComplexTensor
) -> ComplexTensor:
    """Apply a time-varying beamforming vector to a multi-channel mixture.

    For every frame t, computes ``sum_c conj(w[..., t, c]) * mix[..., c, t]``.
    The leading dimensions (e.g. batch and frequency) are only matched, never
    contracted, so there is no interaction between frequencies. 4-D inputs
    shaped [B, F, T, C] / [B, F, C, T] give exactly the same result as
    before. Any number of leading dimensions is now also accepted.

    Args:
        beamforming_vector: [..., T, C], e.g. [B, F, T, C]
        mix: [..., C, T], e.g. [B, F, C, T]

    Returns:
        enhanced signal: [..., T], e.g. [B, F, T]
    """
    # [..., T, C] x [..., C, T] => [..., T]
    # There's no relationship between frequencies.
    es = FC.einsum("...tc,...ct->...t", [beamforming_vector.conj(), mix])
    return es