예제 #1
0
 def forward(self, x: ComplexTensor, eps: float = 1e-5) -> th.Tensor:
     """
     Compute log-magnitude features from the fixed beamformer outputs

     Args:
         x: complex tensor, N x C x F x T (a 3D C x F x T input is treated
            as a batch of one)
         eps: small constant keeping the log argument strictly positive
     Return:
         z: enhanced features, N x T x (P * G)
     Raises:
         RuntimeError: if x is neither 3D nor 4D
     """
     assert isinstance(x, ComplexTensor)
     if x.dim() not in [3, 4]:
         raise RuntimeError(f"Expect 3/4D tensor, got {x.dim()} instead")
     if x.dim() == 3:
         # add the batch dimension
         x = x[None, ...]
     # N x P x T x F
     b = self.beam(x, trans=True, cplx=True)
     if self.spectra_complex:
         # complex projection: N x P x T x G
         w = self.proj(b)
         # magnitude: N x P x T x G. Adding eps to a complex value only shifts
         # its real part, so |w + eps| can still be 0; clamp it so the log
         # below stays finite
         w = th.clamp((w + eps).abs(), min=eps)
     else:
         # magnitude: N x P x T x F
         p = (b + eps).abs()
         # N x P x T x G, relu(...) + eps keeps the log argument >= eps
         w = th.relu(self.proj(p)) + eps
     # log compression: N x P x T x G
     z = th.log(w)
     if self.norm:
         z = self.norm(z)
     # N x P x T x G => N x T x P x G
     z = z.transpose(1, 2).contiguous()
     # N x T x (P * G)
     z = z.view(*z.shape[:2], -1)
     return z
예제 #2
0
def trace(cplx_mat: ComplexTensor) -> ComplexTensor:
    """
    Return the trace of a (batch of) complex square matrices

    Args:
        cplx_mat: complex tensor, ... x C x C
    Return:
        trace: complex tensor, ... (sum over the diagonal of the last two axes)
    Raises:
        RuntimeError: if the last two dimensions are not equal
    """
    mat_size = cplx_mat.shape
    if mat_size[-1] != mat_size[-2]:
        raise RuntimeError(
            f"trace expects square matrices, got shape {tuple(mat_size)}")
    # diagonal() is a view, so no expanded boolean mask or masked copy is
    # needed; the real and imaginary parts are summed independently
    real = cplx_mat.real.diagonal(dim1=-2, dim2=-1).sum(-1)
    imag = cplx_mat.imag.diagonal(dim1=-2, dim2=-1).sum(-1)
    return ComplexTensor(real, imag)
예제 #3
0
 def _norm_abs(self, obs: ComplexTensor) -> ComplexTensor:
     """
     Rescale complex STFTs so their magnitudes have unit L2 norm over dim 1

     The phase is left untouched. The norm is floored at EPSILON before
     dividing, so all-zero slices do not produce NaNs.

     Args:
         obs: complex tensor
     Return:
         complex tensor with the same shape and phase as obs
     """
     magnitude = obs.abs()
     # L2 norm across dim 1, floored at EPSILON to keep the division finite
     scale = th.clamp(th.norm(magnitude, p=2, dim=1, keepdim=True),
                      min=EPSILON)
     return ComplexTensor(magnitude / scale, obs.angle(), polar=True)
예제 #4
0
 def _derive_weight(self,
                    Rs: ComplexTensor,
                    Rn: ComplexTensor,
                    u: th.Tensor,
                    eps: float = 1e-5) -> ComplexTensor:
     """
     Compute MVDR beam weights from speech & noise covariance matrices

     weight = (Rn^-1 Rs u) / (trace(Rn^-1 Rs) + eps), where Rn is first
     diagonally loaded with eps.

     Args:
         Rs, Rn: speech & noise covariance matrices, N x F x C x C
         u: reference selection vector, N x C
         eps: diagonal loading on Rn and offset added to the trace
     Return:
         weight: N x F x C
     """
     num_channels = Rn.shape[-1]
     # diagonal loading so that Rn can be inverted
     identity = th.eye(num_channels, device=Rn.device, dtype=Rn.dtype)
     loaded_noise = Rn + identity * eps
     # N x F x C x C: einsum("...ij,...jk->...ik", Rn_inv, Rs)
     ratio = loaded_noise.inverse() @ Rs
     # N x F
     denom = trace(ratio) + eps
     # N x F x C: einsum("...fnc,...c->...fn", ratio, u)
     numer = (ratio * u[:, None, None, :]).sum(-1)
     # N x F x C
     return numer / denom[..., None]
예제 #5
0
 def log_pdf(self, mask: th.Tensor, obs: ComplexTensor) -> th.Tensor:
     """
     Evaluate the cacgmm log-pdf for every time-frequency bin

     log_pdf = -C * log(max(Re(y^H B^-1 y), eps)) - log(det(B)), where B is
     the mask-weighted covariance plus eps on the diagonal.

     Args:
         mask (Tensor): N x F x T
         obs (ComplexTensor): N x F x C x T
     Return:
         log_pdf (Tensor): N x F x T
     """
     _, _, num_channels, _ = obs.shape
     # N x F x C x C, regularized on the diagonal
     covar = estimate_covar(mask, obs, eps=self.eps)
     identity = th.eye(num_channels, device=covar.device, dtype=covar.dtype)
     covar = covar + identity * self.eps
     # N x F
     det = hermitian_det(covar, eps=self.eps)
     # N x F x T quadratic form: einsum("...xt,...xy,...yt->...t", obs.conj(), B_inv, obs)
     covar_inv = covar.inverse()
     quad = (obs.conj() * (covar_inv @ obs)).sum(-2)
     quad = th.clamp(quad.real, min=self.eps)
     # N x F x T
     return -num_channels * th.log(quad) - th.log(det[..., None])
예제 #6
0
    def forward(self, wav_pad: th.Tensor,
                wav_len: Optional[th.Tensor]) -> EnhReturnType:
        """
        Run the STFT and build the configured spectral/spatial features

        Args:
            wav_pad (Tensor): raw waveform, N x C x S or N x S
            wav_len (Tensor or None): number samples in wav_pad, N or None
        Return:
            feats (Tensor or None): enabled magnitude/IPD features concatenated
                on the last axis (N x T x ...), None when no transform is set
            cplx (ComplexTensor): STFT coefficients, N x (C) x F x T
            num_frames (Tensor or None): number frames in each batch, N or None
        """
        mag, pha = self.forward_stft(wav_pad)
        # rebuild the complex STFT from its polar form: N x C x F x T
        cplx = ComplexTensor(mag, pha, polar=True)
        # gather outputs of whichever feature transforms are enabled
        parts = []
        if self.mag_transform:
            parts.append(self.mag_transform(mag))
        if self.ipd_transform:
            parts.append(self.ipd_transform(pha))
        num_frames = self.num_frames(wav_len)
        feats = check_valid(th.cat(parts, -1),
                            num_frames)[0] if parts else None
        return feats, cplx, num_frames
예제 #7
0
 def forward(self, Rs: ComplexTensor) -> th.Tensor:
     """
     Estimate a soft reference-channel selection vector from a covariance

     Args:
         Rs: complex, N x F x C x C
     Return:
         u: real, N x C (softmax over the channel axis)
     """
     C = Rs.shape[-1]
     I = th.eye(C, device=Rs.device, dtype=th.bool)
     # zero the diagonal and average the C - 1 off-diagonal entries of each
     # row: N x F x C. With a single channel there are no off-diagonal
     # entries, and dividing by C - 1 = 0 would give 0 / 0 = NaN; max(., 1)
     # keeps the input finite (softmax over one channel then yields 1)
     Rs = Rs.masked_fill(I, 0).sum(-1) / max(C - 1, 1)
     # N x C x A
     proj = self.proj(Rs.abs().transpose(1, 2))
     # N x C x 1
     gvec = self.gvec(th.tanh(proj))
     # N x C
     return tf.softmax(gvec.squeeze(-1), -1)
예제 #8
0
 def forward(self, x: ComplexTensor) -> ComplexTensor:
     """
     Apply the complex layer formed by the real sub-layers self.real/self.imag

     The real output is real(x.real) - imag(x.imag); the imaginary output is
     real(x.imag) + imag(x.real).

     args:
         x: complex tensor
     return:
         y: complex tensor
     """
     assert isinstance(x, ComplexTensor)
     x_re, x_im = x.real, x.imag
     return ComplexTensor(self.real(x_re) - self.imag(x_im),
                          self.real(x_im) + self.imag(x_re))
예제 #9
0
def test_fixed_beamformer(batch_size, num_channels, num_bins, num_directions):
    """Check FixedBeamformer output shapes for all beams and a single beam."""
    beamformer = FixedBeamformer(num_directions, num_channels, num_bins)
    T = th.randint(50, 100, (1, )).item()
    # random complex input: N x C x F x T
    shape = (batch_size, num_channels, num_bins, T)
    inp_c = ComplexTensor(th.rand(*shape), th.rand(*shape))
    # every beam: N x B x F x T
    expected_all = th.Size([batch_size, num_directions, num_bins, T])
    assert beamformer(inp_c).shape == expected_all
    # one selected beam: N x F x T
    expected_one = th.Size([batch_size, num_bins, T])
    assert beamformer(inp_c, beam=0).shape == expected_one
예제 #10
0
 def forward(self,
             x: ComplexTensor,
             add_abs: bool = False,
             eps: float = 1e-5) -> Union[ComplexTensor, th.Tensor]:
     """
     Apply the complex layer (self.real, self.imag) to a complex input

     Args:
         x: complex tensor
         add_abs: when True, return the magnitude instead of the complex output
         eps: added under the square root when computing the magnitude
     Return:
         ComplexTensor if add_abs is False, otherwise the real tensor
         sqrt(real^2 + imag^2 + eps)
     """
     assert isinstance(x, ComplexTensor)
     re_in, im_in = x.real, x.imag
     re_out = self.real(re_in) - self.imag(im_in)
     im_out = self.real(im_in) + self.imag(re_in)
     if add_abs:
         return (re_out**2 + im_out**2 + eps)**0.5
     return ComplexTensor(re_out, im_out)
예제 #11
0
 def forward(
     self,
     x: ComplexTensor,
     beam: Optional[Union[int, th.Tensor]] = None,
     squeeze: bool = False,
     trans: bool = False,
     cplx: bool = True
 ) -> Union[ComplexTensor, Tuple[th.Tensor, th.Tensor]]:
     """
     Apply the fixed beamformer: sum over channels of conj(w) * x, where
     w = self.real + j * self.imag

     Args:
         x (Complex Tensor): N x C x F x T
         beam (Tensor, int or None): beam index(es) to output, N for a
             per-utterance selection; None outputs every beam
         squeeze: if True, squeeze ALL singleton axes of the output
         trans: if True, swap the last two axes (F x T => T x F)
         cplx: if True return a ComplexTensor, otherwise (real, imag)
     Return:
         1) (Tensor, Tensor): N x (B) x F x T
         2) (ComplexTensor): N x (B) x F x T
     Raises:
         RuntimeError: if x is not 4D or its channel count does not match
     """
     r, i = x.real, x.imag
     # reject any non-4D input; the former "and" let e.g. 3D inputs through
     # whenever the real and imaginary parts had the same rank
     if r.dim() != i.dim() or r.dim() != 4:
         raise RuntimeError(
             f"FixBeamformer accept 4D tensor, got {r.dim()}")
     if self.real.shape[1] != r.shape[1]:
         raise RuntimeError(f"Number of channels mismatch: "
                            f"{r.shape[1]} vs {self.real.shape[1]}")
     if beam is None:
         # output all the beams, summing over the channel axis
         br = th.sum(r.unsqueeze(1) * self.real, 2) + th.sum(
             i.unsqueeze(1) * self.imag, 2)
         bi = th.sum(i.unsqueeze(1) * self.real, 2) - th.sum(
             r.unsqueeze(1) * self.imag, 2)
     else:
         # output the selected beam(s), summing over the channel axis
         br = th.sum(r * self.real[beam], 1) + th.sum(
             i * self.imag[beam], 1)
         bi = th.sum(i * self.real[beam], 1) - th.sum(
             r * self.imag[beam], 1)
     if squeeze:
         # NOTE(review): squeeze() drops every singleton axis, including the
         # batch axis when N == 1 -- confirm this is intended
         br = br.squeeze()
         bi = bi.squeeze()
     if trans:
         br = br.transpose(-1, -2)
         bi = bi.transpose(-1, -2)
     if cplx:
         return ComplexTensor(br, bi)
     else:
         return br, bi
예제 #12
0
def estimate_covar(mask: th.Tensor,
                   spectrogram: ComplexTensor) -> ComplexTensor:
    """
    Mask-weighted covariance matrix (PSD) estimation

    covar[n, f] = sum_t mask[n, f, t] * s[n, :, f, t] s[n, :, f, t]^H
                  / max(sum_t mask[n, f, t], EPSILON)

    Args:
        mask: TF-masks (real), N x F x T
        spectrogram: complex, N x C x F x T
    Return:
        covar: complex, N x F x C x C
    """
    # move channels next to time: N x F x C x T
    obs = spectrogram.transpose(1, 2)
    # broadcast over channels: N x F x 1 x T
    weight = mask.unsqueeze(-2)
    # N x F x C x C: einsum("...it,...jt->...ij", obs * weight, obs.conj())
    numerator = (obs * weight) @ obs.conj_transpose(-1, -2)
    # N x F x 1 x 1, floored to avoid dividing by zero
    total = th.clamp(weight.sum(-1, keepdims=True), min=EPSILON)
    return numerator / total
예제 #13
0
def estimate_covar(mask: th.Tensor,
                   obs: ComplexTensor,
                   eps: float = EPSILON) -> ComplexTensor:
    """
    Mask-weighted covariance estimation, scaled by the channel count C and
    symmetrized to be Hermitian

    Args:
        mask (Tensor): N x F x T
        obs (ComplexTensor): N x F x C x T
        eps: lower bound on the mask sum used as the denominator
    Return:
        covar (ComplexTensor): N x F x C x C
    """
    _, _, num_channels, _ = obs.shape
    # broadcast over channels: N x F x 1 x T
    weight = mask.unsqueeze(-2)
    # N x F x C x C: einsum("...it,...jt->...ij", obs * weight, obs.conj())
    numerator = (obs * weight) @ obs.conj_transpose(-1, -2)
    # N x F x 1 x 1
    total = th.clamp(weight.sum(-1, keepdims=True), min=eps)
    covar = num_channels * numerator / total
    # average with the conjugate transpose to enforce Hermitian symmetry
    return (covar + covar.conj_transpose(-1, -2)) / 2