Exemplo n.º 1
0
def bound_complex_mask(mask: ComplexTensor, bound_type="tanh"):
    r"""Bound a complex mask, as proposed in [1], section 3.2.

    Valid bound types, for a complex mask $M = |M| ⋅ e^{i φ(M)}$:

    - Unbounded ("UBD"): :math:`M_{\mathrm{UBD}} = M`
    - Sigmoid ("BDSS"): :math:`M_{\mathrm{BDSS}} = σ(|M|) e^{i σ(φ(M))}`
    - Tanh ("BDT"): :math:`M_{\mathrm{BDT}} = \mathrm{tanh}(|M|) e^{i φ(M)}`

    Args:
        mask (ComplexTensor): The complex mask to bound.
        bound_type (str or None): The type of bound to use, either of
            "tanh"/"bdt" (default), "sigmoid"/"bdss" or None/"ubd".
            Matching is case-insensitive.

    Returns:
        The bounded mask, in the same "torch complex" format as ``mask``.

    Raises:
        ValueError: If ``bound_type`` is not one of the accepted values.

    References
        - [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net",
          Hyeong-Seok Choi et al. https://arxiv.org/abs/1903.03107
    """
    # Normalize the bound type so the documented lowercase aliases
    # ("bdt", "bdss", "ubd") are accepted as well as the uppercase ones.
    normalized = bound_type.upper() if isinstance(bound_type, str) else bound_type
    if normalized in {"BDSS", "SIGMOID"}:
        # Sigmoid applied independently to real and imaginary parts.
        return on_reim(torch.sigmoid)(mask)
    elif normalized in {"BDT", "TANH", "UBD", None}:
        mask_mag, mask_phase = transforms.magphase(
            transforms.from_torch_complex(mask))
        if normalized in {"BDT", "TANH"}:
            # Bound the magnitude to [0, 1) while keeping the phase intact.
            mask_mag_bounded = torch.tanh(mask_mag)
        else:
            # Unbounded ("UBD"/None): leave the magnitude as-is.
            mask_mag_bounded = mask_mag
        return torch_complex_from_magphase(mask_mag_bounded, mask_phase)
    else:
        raise ValueError(f"Unknown mask bound {bound_type}")
Exemplo n.º 2
0
def test_torch_complex_format(dim, max_tested_ndim):
    """Round-trip check: from_torch_complex(to_torch_complex(x)) == x."""
    # Draw a random shape, then double the size of the complex dimension
    # so it can hold interleaved real/imaginary parts.
    shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
    shape[dim] *= 2
    original = torch.randn(shape)
    round_tripped = transforms.from_torch_complex(
        transforms.to_torch_complex(original, dim=dim), dim=dim)
    assert_allclose(original, round_tripped)
Exemplo n.º 3
0
 def apply_masks(self, tf_rep, est_masks):
     """Apply estimated complex masks to the TF representation.

     Args:
         tf_rep: Complex TF representation of the mixture; a source axis
             is inserted at dim 1 so it broadcasts against ``est_masks``.
         est_masks: Complex masks estimated by the network
             # assumes shape (batch, n_src, freq, time) — TODO confirm.

     Returns:
         The masked TF representation, converted via ``from_torch_complex``.
     """
     if self.masking_method == "tanh":
         mag_est_masks = torch.abs(est_masks)
         # Scaling factor confines mask to unit disc (see [1]).
         # Clamp avoids 0/0 -> NaN when a mask entry is exactly zero;
         # tanh(x)/x -> 1 as x -> 0, so the clamp is numerically consistent.
         scaling_factor = torch.tanh(mag_est_masks) / mag_est_masks.clamp(min=1e-8)
         masked_tf_rep = scaling_factor * est_masks * tf_rep.unsqueeze(1)
     else:
         masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
     return from_torch_complex(masked_tf_rep)
Exemplo n.º 4
0
 def apply_masks(self, tf_rep, est_masks):
     """Multiply the TF representation by each estimated mask."""
     # Insert a source axis in tf_rep so it broadcasts against the masks.
     expanded_mix = tf_rep.unsqueeze(1)
     return from_torch_complex(expanded_mix * est_masks)
Exemplo n.º 5
0
 def apply_masks(self, tf_rep, est_masks):
     """Apply masks, then pad the missing Nyquist frequency bin back."""
     masked = est_masks * tf_rep.unsqueeze(1)
     # Append one zero bin on the frequency axis (the Nyquist bin).
     padded = torch.nn.functional.pad(masked, [0, 0, 0, 1])
     return from_torch_complex(padded)
Exemplo n.º 6
0
    RTFMVDRBeamformer,
    SDWMWFBeamformer,
    GEVBeamformer,
    stable_cholesky,
)

# True when the installed torch is >= 1.8 (native complex tensor support).
torch_has_complex_support = tuple(map(
    int,
    torch.__version__.split(".")[:2])) >= (1, 8)

_stft, _istft = make_enc_dec("stft",
                             kernel_size=512,
                             n_filters=512,
                             stride=128)


# PEP 8 (E731): named lambdas should be written as ``def`` functions.
def stft(x):
    """STFT of ``x``, returned as a native torch complex tensor."""
    return tr.to_torch_complex(_stft(x))


def istft(x):
    """Inverse STFT of a native torch complex tensor ``x``."""
    return _istft(tr.from_torch_complex(x))


@pytest.mark.skipif(not torch_has_complex_support, "No complex support ")
def _default_beamformer_test(beamformer: Beamformer,
                             n_mics=4,
                             *args,
                             **kwargs):
    scm = SCM()

    speech = torch.randn(1, n_mics, 16000 * 6)
    noise = torch.randn(1, n_mics, 16000 * 6)
    mix = speech + noise
    # GeV Beamforming
    mix_stft = stft(mix)
    speech_stft = stft(speech)