Exemple #1
0
def test_dan_separator_forward_backward_complex(input_dim, rnn_type, layer,
                                                unit, dropout, num_spk, emb_D,
                                                nonlinear):
    """Forward/backward smoke test of DANSeparator on a complex batch.

    The separator receives one reference spectrum per speaker through
    ``additional["feature_ref"]`` (the mixture itself is reused as every
    reference). The test checks that it returns ``num_spk`` ComplexTensor
    estimates and that ``backward()`` runs on the first one.
    """
    model = DANSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        emb_D=emb_D,
        nonlinear=nonlinear,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    # One (independent) reference ComplexTensor per speaker.
    sep_others = {
        "feature_ref": [ComplexTensor(real, imag) for _ in range(num_spk)],
    }

    masked, flens, others = model(x, ilens=x_lens, additional=sep_others)

    assert isinstance(masked[0], ComplexTensor)
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()
Exemple #2
0
def test_dpcl_separator_forward_backward_complex(input_dim, rnn_type, layer,
                                                 unit, dropout, num_spk, emb_D,
                                                 nonlinear):
    """Check that DPCLSeparator reports a ``tf_embedding`` usable in backward.

    A random complex batch of two utterances (lengths 10 and 8) is separated
    in training mode. The ``tf_embedding`` entry of the auxiliary output dict
    is reduced to a scalar and backpropagated.
    """
    separator = DPCLSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        emb_D=emb_D,
        nonlinear=nonlinear,
    ).train()

    shape = (2, 10, input_dim)
    # Arguments are evaluated left to right: real part first, then imaginary.
    mixture = ComplexTensor(torch.rand(*shape), torch.rand(*shape))
    lengths = torch.tensor([10, 8], dtype=torch.long)

    _, _, extras = separator(mixture, ilens=lengths)

    assert "tf_embedding" in extras

    embedding_loss = extras["tf_embedding"].abs().mean()
    embedding_loss.backward()
Exemple #3
0
def test_rnn_separator_forward_backward_complex(input_dim, rnn_type, layer,
                                                unit, dropout, num_spk,
                                                nonlinear):
    """Run RNNSeparator forward and backward on a random complex batch.

    Checks that the separator returns ``num_spk`` ComplexTensor estimates
    and that ``backward()`` runs on the first estimate.
    """
    separator = RNNSeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        nonlinear=nonlinear,
    ).train()

    shape = (2, 10, input_dim)
    mixture = ComplexTensor(torch.rand(*shape), torch.rand(*shape))
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    first = estimates[0]
    assert isinstance(first, ComplexTensor)
    assert len(estimates) == num_spk

    first.abs().mean().backward()
Exemple #4
0
def test_tf_dpcl_loss_criterion_forward(loss_type):
    """FrequencyDomainDPCL must return one loss value per batch item.

    The inferred embedding has shape (batch, 10 * 200, 40). It is compared
    against the magnitudes of three random (batch, 10, 200) reference
    spectra.
    """
    batch = 2
    criterion = FrequencyDomainDPCL(loss_type=loss_type)

    inf = torch.rand(batch, 10 * 200, 40)
    # Three references; each draws its real part, then its imaginary part.
    ref_spec = [
        ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200))
        for _ in range(3)
    ]
    magnitudes = list(map(abs, ref_spec))

    loss = criterion(magnitudes, inf)
    assert loss.shape == (batch, ), "Invalid loss shape with " + criterion.name
def test_STFTDecoder_backward(n_fft, win_length, hop_length, window, center,
                              normalized, onesided):
    """Backpropagate through STFTDecoder from a random complex spectrum.

    The spectrum has 300 frames. Its frequency axis is ``n_fft // 2 + 1``
    bins when ``onesided`` is set, otherwise ``n_fft``. Both the real and
    the imaginary parts require grad.
    """
    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    n_bins = n_fft // 2 + 1 if onesided else n_fft
    real = torch.rand(2, 300, n_bins, requires_grad=True)
    imag = torch.rand(2, 300, n_bins, requires_grad=True)
    spec = ComplexTensor(real, imag)

    frame_counts = (300, 295)
    wav_lens = torch.tensor([count * hop_length for count in frame_counts],
                            dtype=torch.long)

    wav, _ = decoder(spec, wav_lens)
    wav.sum().backward()
Exemple #6
0
def test_tf_domain_criterion_forward(criterion_class, mask_type,
                                     compute_on_mask):
    """A time-frequency criterion must yield one loss value per batch item.

    When ``compute_on_mask`` is set, the target is a mask label derived from
    the mixture and the reference. Otherwise it is the reference magnitude.
    """
    criterion = criterion_class(compute_on_mask=compute_on_mask,
                                mask_type=mask_type)

    batch = 2
    estimate = torch.rand(batch, 10, 200)
    ref_spec = [
        ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200))
    ]
    mix_spec = ComplexTensor(torch.rand(batch, 10, 200),
                             torch.rand(batch, 10, 200))

    targets = (criterion.create_mask_label(mix_spec, ref_spec)
               if compute_on_mask else [abs(r) for r in ref_spec])

    loss = criterion(targets[0], estimate)
    assert loss.shape == (batch, )
def get_power_spectral_density_matrix(
        complex_tensor: ComplexTensor) -> ComplexTensor:
    """
    Per-frame cross-channel power spectral density (PSD) matrix.

    Computes the outer product x x^H of the channel vector at every
    time-frequency bin. No averaging over time is performed, so the time
    axis is kept in the output. NOTE(review): if a time-averaged PSD was
    intended, callers must reduce over dim -3 (the T axis) themselves.

    Args:
        complex_tensor: [..., F, C, T]

    Returns:
        psd: [..., F, T, C, C], where
            psd[..., f, t, c, e] =
                complex_tensor[..., f, c, t] * conj(complex_tensor[..., f, e, t])
    """
    # outer product: [..., C_1, T] x [..., C_2, T] => [..., T, C_1, C_2]
    return FC.einsum("...ct,...et->...tce",
                     [complex_tensor, complex_tensor.conj()])
def apply_crf_filter(cRM_filter: ComplexTensor,
                     mix: ComplexTensor) -> ComplexTensor:
    """
    Filter a mixture with the conjugate of a complex ratio filter.

    For each (b, c, f, t), the output is sum over the filter-delay axis d of
    conj(cRM_filter[b, f, t, d]) * mix[b, c, f, d, t].

    Args:
        cRM_filter: complex ratio filter, [B, F, T, Filter_delay]
        mix: mixture, [B, C, F, Filter_delay, T]

    Returns:
        [B, C, F, T]
    """
    conj_filter = cRM_filter.conj()
    # The same filter is shared by every channel c; only d is contracted.
    return FC.einsum("bftd, bcfdt -> bcft", [conj_filter, mix])
def test_dccrn_separator_forward_backward_complex(
    input_dim,
    num_spk,
    rnn_layer,
    rnn_units,
    masking_mode,
    use_clstm,
    bidirectional,
    use_cbn,
    kernel_size,
    use_builtin_complex,
    use_noise_mask,
):
    """Forward/backward smoke test of DCCRNSeparator.

    The estimates are native complex ``torch.Tensor`` objects when
    ``use_builtin_complex`` is set on torch >= 1.9. Otherwise they are
    ComplexTensor objects.
    """
    config = dict(
        input_dim=input_dim,
        num_spk=num_spk,
        rnn_layer=rnn_layer,
        rnn_units=rnn_units,
        masking_mode=masking_mode,
        use_clstm=use_clstm,
        bidirectional=bidirectional,
        use_cbn=use_cbn,
        kernel_size=kernel_size,
        kernel_num=[32, 64, 128],
        use_builtin_complex=use_builtin_complex,
        use_noise_mask=use_noise_mask,
    )
    separator = DCCRNSeparator(**config).train()

    shape = (2, 10, input_dim)
    mixture = ComplexTensor(torch.rand(*shape), torch.rand(*shape))
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    native_complex = use_builtin_complex and is_torch_1_9_plus
    expected_type = torch.Tensor if native_complex else ComplexTensor
    assert isinstance(estimates[0], expected_type)
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
def test_dc_crn_separator_forward_backward_complex(
    input_dim,
    num_spk,
    input_channels,
    enc_hid_channels,
    enc_layers,
    glstm_groups,
    glstm_layers,
    glstm_bidirectional,
    glstm_rearrange,
    mode,
):
    """Forward/backward smoke test of DC_CRNSeparator with fixed conv geometry.

    The input is a native complex tensor on torch >= 1.9. On older versions
    it is a ComplexTensor.
    """
    separator = DC_CRNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=enc_hid_channels,
        enc_kernel_size=(1, 3),
        enc_padding=(0, 1),
        enc_last_kernel_size=(1, 3),
        enc_last_stride=(1, 2),
        enc_last_padding=(0, 1),
        enc_layers=enc_layers,
        skip_last_kernel_size=(1, 3),
        skip_last_stride=(1, 1),
        skip_last_padding=(0, 1),
        glstm_groups=glstm_groups,
        glstm_layers=glstm_layers,
        glstm_bidirectional=glstm_bidirectional,
        glstm_rearrange=glstm_rearrange,
        mode=mode,
    ).train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    if is_torch_1_9_plus:
        mixture = torch.complex(real, imag)
    else:
        mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    assert is_complex(estimates[0])
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
def test_transformer_separator_forward_backward_complex(
    input_dim,
    adim,
    layers,
    aheads,
    linear_units,
    num_spk,
    nonlinear,
    positionwise_layer_type,
    normalize_before,
    concat_after,
    dropout_rate,
    positional_dropout_rate,
    attention_dropout_rate,
    use_scaled_pos_enc,
):
    """Forward/backward smoke test of TransformerSeparator on complex input.

    Checks that ``num_spk`` ComplexTensor estimates are returned and that
    ``backward()`` runs on the first one.
    """
    config = dict(
        input_dim=input_dim,
        num_spk=num_spk,
        adim=adim,
        aheads=aheads,
        layers=layers,
        linear_units=linear_units,
        positionwise_layer_type=positionwise_layer_type,
        normalize_before=normalize_before,
        concat_after=concat_after,
        dropout_rate=dropout_rate,
        positional_dropout_rate=positional_dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        use_scaled_pos_enc=use_scaled_pos_enc,
        nonlinear=nonlinear,
    )
    separator = TransformerSeparator(**config).train()

    shape = (2, 10, input_dim)
    mixture = ComplexTensor(torch.rand(*shape), torch.rand(*shape))
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    first = estimates[0]
    assert isinstance(first, ComplexTensor)
    assert len(estimates) == num_spk

    first.abs().mean().backward()
def test_dptnet_separator_forward_backward_complex(
    input_dim,
    post_enc_relu,
    rnn_type,
    bidirectional,
    num_spk,
    unit,
    att_heads,
    dropout,
    activation,
    norm_type,
    layer,
    segment_size,
    nonlinear,
):
    """Forward/backward smoke test of DPTNetSeparator on complex input.

    Checks that ``num_spk`` ComplexTensor estimates are returned and that
    ``backward()`` runs on the first one.
    """
    config = dict(
        input_dim=input_dim,
        post_enc_relu=post_enc_relu,
        rnn_type=rnn_type,
        bidirectional=bidirectional,
        num_spk=num_spk,
        unit=unit,
        att_heads=att_heads,
        dropout=dropout,
        activation=activation,
        norm_type=norm_type,
        layer=layer,
        segment_size=segment_size,
        nonlinear=nonlinear,
    )
    separator = DPTNetSeparator(**config).train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    assert isinstance(estimates[0], ComplexTensor)
    assert len(estimates) == num_spk

    loss = estimates[0].abs().mean()
    loss.backward()
def test_dc_crn_separator_output():
    """DC_CRNSeparator exposes one mask per speaker, shaped like its estimate.

    The check runs in eval mode for one and for two speakers on the same
    random 17-bin complex input.
    """
    real = torch.rand(2, 10, 17)
    imag = torch.rand(2, 10, 17)
    if is_torch_1_9_plus:
        mixture = torch.complex(real, imag)
    else:
        mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in (1, 2):
        separator = DC_CRNSeparator(
            input_dim=17,
            num_spk=num_spk,
            input_channels=[2, 2, 4],
        )
        separator.eval()

        specs, _, others = separator(mixture, lengths)
        assert isinstance(specs, list)
        assert isinstance(others, dict)

        for idx in range(num_spk):
            key = "mask_spk{}".format(idx + 1)
            assert key in others
            assert specs[idx].shape == others[key].shape
def test_dc_crn_separator_multich_input(
    num_spk,
    input_channels,
    enc_kernel_size,
    enc_padding,
    enc_last_kernel_size,
    enc_last_stride,
    enc_last_padding,
    skip_last_kernel_size,
    skip_last_stride,
    skip_last_padding,
):
    """Forward/backward test of DC_CRNSeparator with multichannel input.

    The input has ``input_channels[0] // 2`` microphone channels and 33
    frequency bins, laid out as (batch, frames, channels, bins).
    """
    config = dict(
        input_dim=33,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=2,
        enc_kernel_size=enc_kernel_size,
        enc_padding=enc_padding,
        enc_last_kernel_size=enc_last_kernel_size,
        enc_last_stride=enc_last_stride,
        enc_last_padding=enc_last_padding,
        enc_layers=3,
        skip_last_kernel_size=skip_last_kernel_size,
        skip_last_stride=skip_last_stride,
        skip_last_padding=skip_last_padding,
    )
    separator = DC_CRNSeparator(**config).train()

    shape = (2, 10, input_channels[0] // 2, 33)
    real = torch.rand(*shape)
    imag = torch.rand(*shape)
    if is_torch_1_9_plus:
        mixture = torch.complex(real, imag)
    else:
        mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    assert is_complex(estimates[0])
    assert len(estimates) == num_spk

    estimates[0].abs().mean().backward()
Exemple #15
0
def test_STFTDecoder_backward(
    n_fft, win_length, hop_length, window, center, normalized, onesided
):
    """Backpropagate through STFTDecoder from a random complex spectrum.

    The test is skipped on PyTorch < 1.2, which the enhancement task does
    not support.
    """
    if not is_torch_1_2_plus:
        pytest.skip("Pytorch Version Under 1.2 is not supported for Enh task")

    decoder = STFTDecoder(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
    )

    n_bins = n_fft // 2 + 1 if onesided else n_fft
    real = torch.rand(2, 300, n_bins, requires_grad=True)
    imag = torch.rand(2, 300, n_bins, requires_grad=True)
    spec = ComplexTensor(real, imag)
    wav_lens = torch.tensor([300 * hop_length, 295 * hop_length], dtype=torch.long)

    wav, _ = decoder(spec, wav_lens)
    wav.sum().backward()
Exemple #16
0
def test_dpcl_e2e_separator_forward_backward_complex(
    input_dim,
    rnn_type,
    layer,
    unit,
    dropout,
    num_spk,
    predict_noise,
    emb_D,
    nonlinear,
    alpha,
    max_iteration,
):
    """Forward/backward smoke test of DPCLE2ESeparator on complex input.

    Checks that ``num_spk`` ComplexTensor estimates are returned and that
    ``backward()`` runs on the first one.
    """
    separator = DPCLE2ESeparator(
        input_dim=input_dim,
        rnn_type=rnn_type,
        layer=layer,
        unit=unit,
        dropout=dropout,
        num_spk=num_spk,
        predict_noise=predict_noise,
        emb_D=emb_D,
        nonlinear=nonlinear,
        alpha=alpha,
        max_iteration=max_iteration,
    ).train()

    shape = (2, 10, input_dim)
    mixture = ComplexTensor(torch.rand(*shape), torch.rand(*shape))
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    first = estimates[0]
    assert isinstance(first, ComplexTensor)
    assert len(estimates) == num_spk

    first.abs().mean().backward()
def test_rnn_separator_output():
    """DCCRNSeparator exposes one mask per speaker, shaped like its estimate.

    NOTE(review): despite the function name, this exercises DCCRNSeparator,
    not RNNSeparator. The name is kept so test selection does not change.
    """
    real = torch.rand(2, 10, 9)
    imag = torch.rand(2, 10, 9)
    mixture = ComplexTensor(real, imag)
    lengths = torch.tensor([10, 8], dtype=torch.long)

    for num_spk in (1, 2):
        separator = DCCRNSeparator(
            input_dim=9,
            num_spk=num_spk,
            kernel_num=[32, 64, 128],
        )
        separator.eval()

        specs, _, others = separator(mixture, lengths)
        assert isinstance(specs, list)
        assert isinstance(others, dict)

        for idx in range(num_spk):
            key = "mask_spk{}".format(idx + 1)
            assert key in others
            assert specs[idx].shape == others[key].shape
Exemple #18
0
def test_skim_separator_forward_backward_complex(
    input_dim,
    layer,
    causal,
    unit,
    dropout,
    num_spk,
    nonlinear,
    mem_type,
    segment_size,
    seg_overlap,
):
    """Forward/backward smoke test of SkiMSeparator on complex input.

    Checks that ``num_spk`` ComplexTensor estimates are returned and that
    ``backward()`` runs on the first one.
    """
    config = dict(
        input_dim=input_dim,
        causal=causal,
        num_spk=num_spk,
        nonlinear=nonlinear,
        layer=layer,
        unit=unit,
        segment_size=segment_size,
        dropout=dropout,
        mem_type=mem_type,
        seg_overlap=seg_overlap,
    )
    separator = SkiMSeparator(**config).train()

    shape = (2, 10, input_dim)
    mixture = ComplexTensor(torch.rand(*shape), torch.rand(*shape))
    lengths = torch.tensor([10, 8], dtype=torch.long)

    estimates, _, _ = separator(mixture, ilens=lengths)

    assert isinstance(estimates[0], ComplexTensor)
    assert len(estimates) == num_spk

    loss = estimates[0].abs().mean()
    loss.backward()
Exemple #19
0
def test_tcn_separator_forward_backward_complex(
    input_dim,
    layer,
    num_spk,
    nonlinear,
    stack,
    bottleneck_dim,
    hidden_dim,
    kernel,
    causal,
    norm_type,
):
    """Forward/backward smoke test of TCNSeparator on a complex batch.

    Checks that ``num_spk`` ComplexTensor estimates are returned and that
    ``backward()`` runs on the first one.
    """
    model = TCNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        layer=layer,
        stack=stack,
        bottleneck_dim=bottleneck_dim,
        hidden_dim=hidden_dim,
        kernel=kernel,
        causal=causal,
        norm_type=norm_type,
        nonlinear=nonlinear,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, flens, others = model(x, ilens=x_lens)

    assert isinstance(masked[0], ComplexTensor)
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()
def apply_beamforming_vector(beamforming_vector: ComplexTensor,
                             mix: ComplexTensor) -> ComplexTensor:
    """Apply a per-frame beamforming vector to a multichannel mixture.

    For each (b, f, t) the output is w^H x, i.e. the sum over channels c of
    conj(beamforming_vector[b, f, t, c]) * mix[b, f, c, t].

    Args:
        beamforming_vector: [B, F, T, C]
        mix: [B, F, C, T]

    Returns:
        [B, F, T]
    """
    weights_h = beamforming_vector.conj()
    # Only the channel axis is contracted; frequencies do not interact.
    return FC.einsum("bftc, bfct -> bft", [weights_h, mix])