def bound_complex_mask(mask: ComplexTensor, bound_type="tanh"):
    r"""Bound a complex mask, as proposed in [1], section 3.2.

    Valid bound types, for a complex mask :math:`M = |M| ⋅ e^{i φ(M)}`:

    - Unbounded ("UBD"): :math:`M_{\mathrm{UBD}} = M`
    - Sigmoid ("BDSS"): :math:`M_{\mathrm{BDSS}} = σ(|M|) e^{i σ(φ(M))}`
    - Tanh ("BDT"): :math:`M_{\mathrm{BDT}} = \mathrm{tanh}(|M|) e^{i φ(M)}`

    Args:
        mask (ComplexTensor): The complex mask to bound.
        bound_type (str or None): The type of bound to use, either of
            "tanh"/"bdt" (default), "sigmoid"/"bdss" or None/"ubd".

    Returns:
        The bounded complex mask.

    Raises:
        ValueError: If ``bound_type`` is none of the valid options.

    References
        - [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net",
          Hyeong-Seok Choi et al.
          https://arxiv.org/abs/1903.03107
    """
    if bound_type in {"BDSS", "sigmoid"}:
        # Sigmoid bound acts on real and imaginary parts independently.
        return on_reim(torch.sigmoid)(mask)
    elif bound_type in {"BDT", "tanh", "UBD", None}:
        mask_mag, mask_phase = transforms.magphase(transforms.from_torch_complex(mask))
        if bound_type in {"BDT", "tanh"}:
            # Bound the magnitude to (0, 1); the phase is left untouched.
            mask_mag_bounded = torch.tanh(mask_mag)
        else:
            # Unbounded ("UBD"/None): pass the magnitude through unchanged.
            mask_mag_bounded = mask_mag
        return torch_complex_from_magphase(mask_mag_bounded, mask_phase)
    else:
        raise ValueError(f"Unknown mask bound {bound_type}")
def test_torch_complex_format(dim, max_tested_ndim):
    """Check that cat-complex <-> torch-native complex round-trips exactly."""
    # Draw a random shape, then double the size of the complex dimension so
    # it can hold the interleaved real/imaginary halves (must be even).
    shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
    shape[dim] *= 2
    cat_complex = torch.randn(shape)
    native = transforms.to_torch_complex(cat_complex, dim=dim)
    round_tripped = transforms.from_torch_complex(native, dim=dim)
    assert_allclose(cat_complex, round_tripped)
def apply_masks(self, tf_rep, est_masks):
    """Apply estimated complex masks to the mixture TF representation.

    Args:
        tf_rep (torch.Tensor): Complex TF representation of the mixture.
        est_masks (torch.Tensor): Estimated complex masks, one per source
            along dim 1 after broadcasting.

    Returns:
        The masked TF representation, converted with ``from_torch_complex``.
    """
    if self.masking_method == "tanh":
        mag_est_masks = torch.abs(est_masks)
        # Scaling factor confines mask to unit disc (see [1]).
        # Clamp the denominator: at zero magnitude tanh(x)/x would be
        # 0/0 = NaN; the limit is 1 and the mask is zero there anyway,
        # so the masked output correctly stays zero.
        scaling_factor = torch.tanh(mag_est_masks) / mag_est_masks.clamp(min=1e-8)
        masked_tf_rep = scaling_factor * est_masks * tf_rep.unsqueeze(1)
    else:
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
    return from_torch_complex(masked_tf_rep)
def apply_masks(self, tf_rep, est_masks):
    """Broadcast-multiply the TF representation by each estimated mask."""
    # Insert a source axis into the mixture representation so every mask
    # (stacked along dim 1) is applied to its own copy.
    per_source = tf_rep.unsqueeze(1) * est_masks
    return from_torch_complex(per_source)
def apply_masks(self, tf_rep, est_masks):
    """Apply the estimated masks, then pad the Nyquist frequency bin back."""
    per_source = est_masks * tf_rep.unsqueeze(1)
    # Pad one zero bin onto the frequency axis (Nyquist frequency bin)
    # before converting back from the torch-native complex layout.
    padded = torch.nn.functional.pad(per_source, [0, 0, 0, 1])
    return from_torch_complex(padded)
RTFMVDRBeamformer, SDWMWFBeamformer, GEVBeamformer, stable_cholesky, )

# True from torch 1.8 onward, which this module treats as the first version
# with the complex-tensor support these tests need.
torch_has_complex_support = tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 8)

# Shared STFT encoder/decoder pair used by the beamformer tests below.
_stft, _istft = make_enc_dec("stft", kernel_size=512, n_filters=512, stride=128)
# Convenience wrappers converting between the encoder's cat-complex layout
# and torch-native complex tensors.
stft = lambda x: tr.to_torch_complex(_stft(x))
istft = lambda x: _istft(tr.from_torch_complex(x))


# NOTE(review): pytest expects the skip reason as a keyword argument
# (reason="..."); a positional string here is interpreted as an extra
# condition by skipif — verify against the pytest version in use.
@pytest.mark.skipif(not torch_has_complex_support, "No complex support ")
def _default_beamformer_test(beamformer: Beamformer, n_mics=4, *args, **kwargs):
    # Spatial covariance matrix estimator (project helper).
    scm = SCM()
    # Six seconds of 16 kHz multichannel white noise standing in for
    # speech and noise signals.
    speech = torch.randn(1, n_mics, 16000 * 6)
    noise = torch.randn(1, n_mics, 16000 * 6)
    mix = speech + noise
    # GeV Beamforming
    mix_stft = stft(mix)
    speech_stft = stft(speech)