def forward(
    self, xs: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """Estimate one mask per entry of ``self.linears`` from a complex spectrum.

    The magnitude of each channel is passed independently through
    ``self.brnn`` (channels are folded into the batch axis), then every
    projection in ``self.linears`` followed by ``self.nonlinear`` yields
    one mask.

    Args:
        xs: Complex input, (B, F, C, T).
        ilens: Number of valid frames per utterance, (B,).

    Returns:
        masks: Tuple with one mask per entry of ``self.linears``, each
            (B, F, C, T). Frames at or beyond ``ilens`` are zero.
        ilens: The input ``ilens``, returned unchanged, (B,).
    """
    assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
    _, _, C, input_length = xs.size()
    # (B, F, C, T) -> (B, C, T, F)
    xs = xs.permute(0, 2, 3, 1)

    # Magnitude: sqrt(real^2 + imag^2), (B, C, T, F) -> (B, C, T, F)
    xs = (xs.real**2 + xs.imag**2) ** 0.5
    # Fold channels into the batch so each channel runs through brnn
    # independently: (B, C, T, F) -> (B * C, T, F)
    xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
    # Repeat each utterance length once per channel: (B,) -> (B * C,)
    ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

    # (B * C, T, F) -> (B * C, T', D); T' may be shorter than T (see the
    # padding step below) -- presumably brnn trims to max(ilens), confirm.
    xs, _, _ = self.brnn(xs, ilens_)
    # (B * C, T', D) -> (B, C, T', D)
    xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

    masks = []
    for linear in self.linears:
        # (B, C, T', D) -> (B, C, T', F)
        mask = linear(xs)

        # Any other value of self.nonlinear leaves the linear output as-is.
        if self.nonlinear == "sigmoid":
            mask = torch.sigmoid(mask)
        elif self.nonlinear == "relu":
            mask = torch.relu(mask)
        elif self.nonlinear == "tanh":
            mask = torch.tanh(mask)
        elif self.nonlinear == "crelu":
            mask = torch.clamp(mask, min=0, max=1)

        # Zero the padded frames along T (dim 2). masked_fill is
        # out-of-place, so its result must be reassigned; discarding it
        # left the padding unmasked.
        mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

        # (B, C, T', F) -> (B, F, C, T')
        mask = mask.permute(0, 3, 1, 2)

        # Multi-GPU case: the padded input can be longer than max(ilens),
        # so pad the time axis back to the input length with zeros.
        if mask.size(-1) < input_length:
            mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
        masks.append(mask)

    return tuple(masks), ilens
def forward(
    self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
    """Apply ``self.iterations`` passes of ``wpe_one_iteration`` to ``data``.

    On the first pass, when ``self.use_dnn_mask`` is set, the power used by
    WPE is weighted by the single mask returned from ``self.mask_est``
    (optionally normalized to sum to one along time).

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq or Some dimension of the feature vector

    Args:
        data: (B, T, C, F)
        ilens: (B,)

    Returns:
        enhanced: (B, T, C, F), cast back to the dtype of ``data``
        ilens: (B,), returned unchanged
        mask: (B, T, C, F) mask from ``self.mask_est`` when it was used,
            otherwise None
    """
    # (B, T, C, F) -> (B, F, C, T)
    observed = data.permute(0, 3, 2, 1)
    estimate = observed
    dnn_mask = None

    for step in range(self.iterations):
        # Power of the current estimate: (..., C, T)
        psd = estimate.real**2 + estimate.imag**2

        # Only the first pass is weighted by the DNN mask.
        if step == 0 and self.use_dnn_mask:
            # dnn_mask: (B, F, C, T)
            (dnn_mask,), _ = self.mask_est(estimate, ilens)
            if self.normalization:
                # Make the mask sum to one along T
                dnn_mask = dnn_mask / dnn_mask.sum(dim=-1, keepdim=True)
            psd = psd * dnn_mask

        # Channel average: (..., C, T) -> (..., T)
        psd = psd.mean(dim=-2)

        # NOTE(kamo): the WPE step runs in double precision, then the result
        # is cast back to the working dtype.
        estimate = wpe_one_iteration(
            observed.contiguous().double(),
            psd.double(),
            taps=self.taps,
            delay=self.delay,
            inverse_power=self.inverse_power,
        ).type(observed.dtype)
        # Zero (in place) every frame beyond each utterance's length.
        estimate.masked_fill_(make_pad_mask(ilens, estimate.real), 0)

    # (B, F, C, T) -> (B, T, C, F)
    estimate = estimate.permute(0, 3, 2, 1)
    if dnn_mask is not None:
        # (B, F, C, T) -> (B, T, C, F)
        dnn_mask = dnn_mask.transpose(-1, -3)
    return estimate, ilens, dnn_mask
def forward(
    self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor, list]:
    """DNN_WPE forward function.

    Keeps ``self.nmask`` parallel estimates and runs ``self.iterations``
    passes of ``wpe_one_iteration`` for each. On the first pass, when
    ``self.use_dnn_mask`` is set, each estimate's power is weighted by the
    corresponding mask from ``self.mask_est``.

    Notation:
        B: Batch
        C: Channel
        T: Time or Sequence length
        F: Freq or Some dimension of the feature vector

    Args:
        data: (B, T, C, F)
        ilens: (B,)

    Returns:
        enhanced (ComplexTensor or List[ComplexTensor]): (B, T, C, F) each;
            a single tensor when ``self.nmask == 1``, otherwise a list.
        ilens: (B,), returned unchanged.
        masks (torch.Tensor or List[torch.Tensor] or None): (B, T, C, F)
            each; a single tensor when ``self.nmask == 1``, otherwise a
            list; None when no DNN mask was estimated.
        power (List[torch.Tensor] or None): (B, F, T) channel-averaged power
            from the last pass; None when ``self.iterations == 0``.
    """
    # (B, T, C, F) -> (B, F, C, T)
    data = data.permute(0, 3, 2, 1)
    # One running estimate per mask, all starting from the observation.
    enhanced = [data for _ in range(self.nmask)]
    masks = None
    power = None

    for i in range(self.iterations):
        # Power of each estimate: (..., C, T)
        power = [enh.real**2 + enh.imag**2 for enh in enhanced]
        if i == 0 and self.use_dnn_mask:
            # masks: one (B, F, C, T) tensor per mask
            masks, _ = self.mask_est(data, ilens)
            # Floor masks to increase numerical stability
            if self.mask_flooring:
                masks = [m.clamp(min=self.flooring_thres) for m in masks]
            if self.normalization:
                # Normalize along T
                masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
            # (..., C, T) * (..., C, T) -> (..., C, T)
            # The index is ``k`` so it is not confused with the outer ``i``.
            power = [p * masks[k] for k, p in enumerate(power)]

        # Average over channels, floored at eps: (..., C, T) -> (..., T)
        power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

        # enhanced: (..., C, T) -> (..., C, T)
        # NOTE(kamo): Calculate in double precision
        enhanced = [
            wpe_one_iteration(
                data.contiguous().double(),
                p.double(),
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )
            for p in power
        ]
        # Cast back to the input dtype and zero frames beyond ``ilens``.
        enhanced = [
            enh.to(dtype=data.dtype).masked_fill(
                make_pad_mask(ilens, enh.real), 0
            )
            for enh in enhanced
        ]

    # (B, F, C, T) -> (B, T, C, F)
    enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
    if masks is not None:
        # (B, F, C, T) -> (B, T, C, F)
        if self.nmask > 1:
            masks = [m.transpose(-1, -3) for m in masks]
        else:
            masks = masks[0].transpose(-1, -3)
    if self.nmask == 1:
        enhanced = enhanced[0]

    return enhanced, ilens, masks, power