def forward(self, x: ComplexTensor, eps: float = 1e-5) -> th.Tensor:
    """
    Beamform the input and turn every beam into log-scale features

    Args:
        x: complex tensor, N x C x F x T (a 3D input gets a leading
           batch axis prepended)
        eps: small constant added before abs() and before log()
    Return:
        z: log features with the beam and feature axes merged,
           N x T x (P * G)
    """
    assert isinstance(x, ComplexTensor)
    num_dims = x.dim()
    if num_dims not in [3, 4]:
        raise RuntimeError(f"Expect 3/4D tensor, got {x.dim()} instead")
    if num_dims == 3:
        # add the batch axis
        x = x[None, ...]
    # beam outputs (complex): N x P x T x F
    beams = self.beam(x, trans=True, cplx=True)
    if not self.spectra_complex:
        # magnitude first, then a real projection: N x P x T x G
        magnitude = (beams + eps).abs()
        energy = th.relu(self.proj(magnitude)) + eps
    else:
        # complex projection first, then magnitude: N x P x T x G
        energy = (self.proj(beams) + eps).abs()
    feats = th.log(energy)
    if self.norm:
        feats = self.norm(feats)
    # N x P x T x G => N x T x P x G
    feats = feats.transpose(1, 2).contiguous()
    # merge the last two axes: N x T x (P * G)
    return feats.view(*feats.shape[:2], -1)
def trace(cplx_mat: ComplexTensor) -> ComplexTensor:
    """
    Sum the main-diagonal entries of (a batch of) complex matrices

    Args:
        cplx_mat: complex matrices, the last two axes hold each matrix
    Return:
        the diagonal sum taken over the last axis, one value per matrix
    """
    shape = cplx_mat.size()
    # boolean mask that is True on the main diagonal, broadcast to the batch
    on_diag = th.eye(shape[-1], dtype=th.bool, device=cplx_mat.device)
    on_diag = on_diag.expand(*shape)
    # pick the diagonal entries, then restore the leading axes: ... x C
    diag = cplx_mat.masked_select(on_diag)
    diag = diag.view(*shape[:-1])
    return diag.sum(-1)
def _norm_abs(self, obs: ComplexTensor) -> ComplexTensor:
    """
    Rescale the magnitude of complex-valued STFTs and keep the phase

    The magnitude is divided by its L2 norm along axis 1 (floored at
    EPSILON); the result is rebuilt from the new magnitude and the
    original angle.
    """
    phase = obs.angle()
    magnitude = obs.abs()
    # L2 norm over axis 1, kept as a size-1 axis so it broadcasts
    scale = th.norm(magnitude, p=2, dim=1, keepdim=True)
    scale = th.clamp(scale, min=EPSILON)
    return ComplexTensor(magnitude / scale, phase, polar=True)
def _derive_weight(self,
                   Rs: ComplexTensor,
                   Rn: ComplexTensor,
                   u: th.Tensor,
                   eps: float = 1e-5) -> ComplexTensor:
    """
    Compute mvdr beam weights: (Rn^-1 Rs u) / (trace(Rn^-1 Rs) + eps)

    Args:
        Rs, Rn: speech & noise covariance matrices, N x F x C x C
        u: reference selection vector, N x C
        eps: added to the diagonal of Rn and to the trace
    Return:
        weight: N x F x C
    """
    num_channels = Rn.shape[-1]
    # diagonal loading on the noise covariance before inverting it
    eye = th.eye(num_channels, device=Rn.device, dtype=Rn.dtype)
    loaded_noise = Rn + eye * eps
    # N x F x C x C
    ratio = loaded_noise.inverse() @ Rs
    # N x F
    denominator = trace(ratio) + eps
    # N x F x C: einsum("...fnc,...c->...fn", ratio, u)
    numerator = (ratio * u[:, None, None, :]).sum(-1)
    return numerator / denominator[..., None]
def log_pdf(self, mask: th.Tensor, obs: ComplexTensor) -> th.Tensor:
    """
    Compute log-pdf of the cacgmm distributions

    Args:
        mask (Tensor): N x F x T
        obs (ComplexTensor): N x F x C x T
    Return:
        log_pdf (Tensor): -C * log(obs^H B^-1 obs) - log(det(B)), N x F x T
    """
    _, _, num_channels, _ = obs.shape
    # covariance estimate: N x F x C x C
    covar = estimate_covar(mask, obs, eps=self.eps)
    # load the diagonal with eps
    eye = th.eye(num_channels, device=covar.device, dtype=covar.dtype)
    covar = covar + eye * self.eps
    # N x F
    det = hermitian_det(covar, eps=self.eps)
    # N x F x C x T
    whitened = covar.inverse() @ obs
    # N x F x T: einsum("...xt,...xy,...yt->...t", obs.conj(), covar_inv, obs)
    quad = (obs.conj() * whitened).sum(-2)
    # only the real part is used, floored so that the log stays finite
    quad = th.clamp(quad.real, min=self.eps)
    return -num_channels * th.log(quad) - th.log(det[..., None])
def forward(self, wav_pad: th.Tensor,
            wav_len: Optional[th.Tensor]) -> EnhReturnType:
    """
    Run the STFT and build the (optional) spectral & spatial features

    Args:
        wav_pad (Tensor): raw waveform, N x C x S or N x S
        wav_len (Tensor or None): number samples in wav_pad, N or None
    Return:
        feats (Tensor or None): concatenated features, N x T x ...,
            None when neither transform is configured
        cplx (ComplexTensor): STFT coefficients, N x (C) x F x T
        num_frames (Tensor or None): number frames in each batch, N or None
    """
    # magnitude & phase: N x C x F x T
    magnitude, phase = self.forward_stft(wav_pad)
    # STFT coefficients built from the polar form
    cplx = ComplexTensor(magnitude, phase, polar=True)
    feat_list = []
    # spectral branch works on the magnitude
    if self.mag_transform:
        feat_list.append(self.mag_transform(magnitude))
    # spatial (ipd) branch works on the phase
    if self.ipd_transform:
        feat_list.append(self.ipd_transform(phase))
    num_frames = self.num_frames(wav_len)
    if not feat_list:
        return None, cplx, num_frames
    # concatenate on the last axis, then keep the first output of check_valid
    feats = check_valid(th.cat(feat_list, -1), num_frames)[0]
    return feats, cplx, num_frames
def forward(self, Rs: ComplexTensor) -> th.Tensor:
    """
    Derive softmax weights over channels from a covariance matrix

    Args:
        Rs: complex, N x F x C x C
    Return:
        u: real, N x C
    """
    num_channels = Rs.shape[-1]
    diag_mask = th.eye(num_channels, device=Rs.device, dtype=th.bool)
    # zero the diagonal, then average the remaining C - 1 entries: N x F x C
    off_diag = Rs.masked_fill(diag_mask, 0).sum(-1)
    off_diag = off_diag / (num_channels - 1)
    # magnitude with the channel axis moved forward: N x C x F
    magnitude = off_diag.abs().transpose(1, 2)
    # N x C x A
    hidden = th.tanh(self.proj(magnitude))
    # N x C x 1 => N x C
    score = self.gvec(hidden).squeeze(-1)
    return tf.softmax(score, -1)
def forward(self, x: ComplexTensor) -> ComplexTensor:
    """
    Apply the complex transform built from self.real & self.imag:
    (real + j * imag)(xr + j * xi)

    args:
        x: complex tensor
    return:
        y: complex tensor
    """
    assert isinstance(x, ComplexTensor)
    xr, xi = x.real, x.imag
    # real part: real(xr) - imag(xi)
    out_real = self.real(xr) - self.imag(xi)
    # imaginary part: real(xi) + imag(xr)
    out_imag = self.real(xi) + self.imag(xr)
    return ComplexTensor(out_real, out_imag)
def test_fixed_beamformer(batch_size, num_channels, num_bins, num_directions):
    """
    Check the output shapes of FixedBeamformer, both when all beams are
    requested and when a single beam index is selected
    """
    beamformer = FixedBeamformer(num_directions, num_channels, num_bins)
    # random utterance length
    num_frames = th.randint(50, 100, (1, )).item()
    inp_shape = (batch_size, num_channels, num_bins, num_frames)
    inp_r = th.rand(*inp_shape)
    inp_i = th.rand(*inp_shape)
    inp_c = ComplexTensor(inp_r, inp_i)
    # all beams: N x B x F x T
    all_beams = beamformer(inp_c)
    expected_all = th.Size([batch_size, num_directions, num_bins, num_frames])
    assert all_beams.shape == expected_all
    # one selected beam: N x F x T
    one_beam = beamformer(inp_c, beam=0)
    expected_one = th.Size([batch_size, num_bins, num_frames])
    assert one_beam.shape == expected_one
def forward(self,
            x: ComplexTensor,
            add_abs: bool = False,
            eps: float = 1e-5) -> Union[ComplexTensor, th.Tensor]:
    """
    Apply the complex transform built from self.real & self.imag

    Args:
        x: complex tensor
        add_abs: return the magnitude of the result instead of the
            complex result itself
        eps: added under the square root when add_abs is set
    Return:
        ComplexTensor if add_abs is False, else the real-valued
        sqrt(re^2 + im^2 + eps)
    """
    # x: complex tensor
    assert isinstance(x, ComplexTensor)
    out_real = self.real(x.real) - self.imag(x.imag)
    out_imag = self.real(x.imag) + self.imag(x.real)
    if add_abs:
        return (out_real**2 + out_imag**2 + eps)**0.5
    return ComplexTensor(out_real, out_imag)
def forward(
    self,
    x: ComplexTensor,
    beam: Optional[th.Tensor] = None,
    squeeze: bool = False,
    trans: bool = False,
    cplx: bool = True
) -> Union[ComplexTensor, Tuple[th.Tensor, th.Tensor]]:
    """
    Apply the fixed beam weights (self.real & self.imag) to a
    multi-channel STFT, i.e., sum over channels of conj(w) * x

    Args:
        x (Complex Tensor): N x C x F x T
        beam (Tensor or None): N, index used to pick the weights of one
            beam (self.real[beam]); all beams are produced if None
        squeeze (bool): call .squeeze() on the outputs, which removes
            every axis of size one
        trans (bool): swap the last two axes of the outputs
        cplx (bool): return a ComplexTensor rather than a
            (real, imag) tuple
    Return:
        1) (Tensor, Tensor): N x (B) x F x T
        2) (ComplexTensor): N x (B) x F x T
    Raises:
        RuntimeError: if x is not a 4D tensor or its channel axis does
            not match the beam weights
    """
    r, i = x.real, x.imag
    # Either failing check must reject the input. The former "and" never
    # fired when real & imag parts share a rank, so non-4D input passed.
    if r.dim() != i.dim() or r.dim() != 4:
        raise RuntimeError(
            f"FixBeamformer accept 4D tensor, got {r.dim()}")
    if self.real.shape[1] != r.shape[1]:
        raise RuntimeError(f"Number of channels mismatch: "
                           f"{r.shape[1]} vs {self.real.shape[1]}")
    if beam is None:
        # output all the beam: add a beam axis to the input and reduce
        # the channel axis (now axis 2)
        br = th.sum(r.unsqueeze(1) * self.real, 2) + th.sum(
            i.unsqueeze(1) * self.imag, 2)
        bi = th.sum(i.unsqueeze(1) * self.real, 2) - th.sum(
            r.unsqueeze(1) * self.imag, 2)
    else:
        # output selected beam: reduce the channel axis (axis 1)
        br = th.sum(r * self.real[beam], 1) + th.sum(
            i * self.imag[beam], 1)
        bi = th.sum(i * self.real[beam], 1) - th.sum(
            r * self.imag[beam], 1)
    if squeeze:
        br = br.squeeze()
        bi = bi.squeeze()
    if trans:
        # ... x F x T => ... x T x F
        br = br.transpose(-1, -2)
        bi = bi.transpose(-1, -2)
    if cplx:
        return ComplexTensor(br, bi)
    else:
        return br, bi
def estimate_covar(mask: th.Tensor,
                   spectrogram: ComplexTensor,
                   eps: float = EPSILON) -> ComplexTensor:
    """
    Covariance matrices (PSD) estimation

    Args:
        mask: TF-masks (real), N x F x T
        spectrogram: complex, N x C x F x T
        eps: floor applied to the mask sum before the division
            (defaults to EPSILON, the value that was hard-coded before)
    Return:
        covar: complex, N x F x C x C
    """
    # N x F x C x T
    spec = spectrogram.transpose(1, 2)
    # N x F x 1 x T
    mask = mask.unsqueeze(-2)
    # N x F x C x C: einsum("...it,...jt->...ij", spec * mask, spec.conj())
    numerator = (spec * mask) @ spec.conj_transpose(-1, -2)
    # N x F x 1 x 1, floored so that an all-zero mask cannot divide by zero
    denominator = th.clamp(mask.sum(-1, keepdims=True), min=eps)
    # N x F x C x C
    return numerator / denominator
def estimate_covar(mask: th.Tensor,
                   obs: ComplexTensor,
                   eps: float = EPSILON) -> ComplexTensor:
    """
    Mask-weighted covariance estimation, scaled by the channel count
    and made Hermitian

    Args:
        mask (Tensor): N x F x T
        obs (ComplexTensor): N x F x C x T
        eps (float): floor applied to the mask sum before the division
    Return:
        covar (ComplexTensor): N x F x C x C
    """
    _, _, num_channels, _ = obs.shape
    # N x F x 1 x T
    weight = mask.unsqueeze(-2)
    # N x F x C x C: einsum("...it,...jt->...ij", obs * weight, obs.conj())
    outer = (obs * weight) @ obs.conj_transpose(-1, -2)
    # N x F x 1 x 1
    total = th.clamp(weight.sum(-1, keepdims=True), min=eps)
    # N x F x C x C
    covar = num_channels * outer / total
    # average with the conjugate transpose so the result is Hermitian
    return (covar + covar.conj_transpose(-1, -2)) / 2