from typing import Tuple

import torch
import torch.nn.functional as F
from torch_complex.tensor import ComplexTensor

# Project-local helper that builds a boolean mask over the padded region of a
# batch of variable-length sequences.
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


def forward(self, xs: ComplexTensor, ilens: torch.LongTensor) \
        -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """The forward function

    Args:
        xs: (B, F, C, T)
        ilens: (B,)
    Returns:
        masks: A tuple of the masks. (B, F, C, T)
        ilens: (B,)
    """
    assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
    _, _, C, input_length = xs.size()
    # (B, F, C, T) -> (B, C, T, F)
    xs = xs.permute(0, 2, 3, 1)

    # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
    xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
    # xs: (B, C, T, F) -> xs: (B * C, T, F)
    xs = xs.view(-1, xs.size(-2), xs.size(-1))
    # ilens: (B,) -> ilens_: (B * C,)
    ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

    # xs: (B * C, T, F) -> xs: (B * C, T, D)
    xs, _, _ = self.brnn(xs, ilens_)
    # xs: (B * C, T, D) -> xs: (B, C, T, D)
    xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

    masks = []
    for linear in self.linears:
        # xs: (B, C, T, D) -> mask: (B, C, T, F)
        mask = linear(xs)
        mask = torch.sigmoid(mask)

        # Zero out the padded region (masked_fill is out-of-place,
        # so the result must be reassigned)
        mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
        # (B, C, T, F) -> (B, F, C, T)
        mask = mask.permute(0, 3, 1, 2)

        # Takes care of multi-GPU cases: if input_length > max(ilens),
        # pad the time axis back to the full input length
        if mask.size(-1) < input_length:
            mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
        masks.append(mask)

    return tuple(masks), ilens
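
# A standalone sketch of the channel-flattening trick used above: each of the
# C channels is fed to the RNN as an independent sequence, so the per-utterance
# lengths in `ilens` are replicated once per channel. The function name and
# shapes are illustrative assumptions, not part of the source.
def _demo_channel_flattening():
    B, C, T, F_bins = 2, 3, 5, 4
    xs = torch.randn(B, C, T, F_bins)
    ilens = torch.LongTensor([5, 4])

    # xs: (B, C, T, F) -> (B * C, T, F)
    xs_flat = xs.view(-1, xs.size(-2), xs.size(-1))
    # ilens: (B,) -> (B * C,)
    ilens_flat = ilens[:, None].expand(-1, C).contiguous().view(-1)

    print(xs_flat.shape)   # torch.Size([6, 5, 4])
    print(ilens_flat)      # tensor([5, 5, 5, 4, 4, 4])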
def forward(self, xs: ComplexTensor,
            input_lengths: torch.LongTensor) -> torch.Tensor:
    """The forward function

    Args:
        xs: (B, C, T, F)
        input_lengths: (B,)
    Returns:
        out: The mask. (B, C, T, F)
    """
    assert xs.size(0) == input_lengths.size(0), \
        (xs.size(0), input_lengths.size(0))

    # xs: (B, C, T, F)
    C = xs.size(1)

    if self.feat_type == 'amplitude':
        # xs: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
    elif self.feat_type == 'power':
        # xs: (B, C, T, F) -> (B, C, T, F)
        xs = xs.real ** 2 + xs.imag ** 2
    elif self.feat_type == 'log_power':
        # xs: (B, C, T, F) -> (B, C, T, F)
        xs = torch.log(xs.real ** 2 + xs.imag ** 2)
    elif self.feat_type == 'concat':
        # xs: (B, C, T, F) -> (B, C, T, 2 * F)
        xs = torch.cat([xs.real, xs.imag], -1)
    else:
        raise NotImplementedError(f'Not implemented: {self.feat_type}')

    if self.model_type in ('blstm', 'lstm'):
        # xs: (B, C, T, F) -> xs: (B, C, T, D)
        xs = self.net(xs, input_lengths)
    elif self.model_type == 'cnn':
        if self.channel_independent:
            # Fold the channel axis into the batch so each channel is
            # processed independently:
            # xs: (B, C, T, F) -> xs: (B * C, F, T)
            xs = xs.view(-1, *xs.size()[2:]).transpose(1, 2)
            # xs: (B * C, F, T) -> xs: (B * C, D, T)
            xs = self.net(xs)
            # xs: (B * C, D, T) -> (B, C, T, D)
            xs = xs.transpose(1, 2).contiguous().view(
                -1, C, xs.size(2), xs.size(1))
        else:
            # xs: (B, C, T, F) -> xs: (B, C, T, F)
            xs = self.net(xs)
    else:
        raise NotImplementedError(f'Not implemented: {self.model_type}')

    # xs: (B, C, T, D) -> out: (B, C, T, F)
    out = self.linear(xs)

    out = torch.sigmoid(out)
    # Zero out the padded region (masked_fill is out-of-place,
    # so the result must be reassigned)
    out = out.masked_fill(make_pad_mask(input_lengths, out, length_dim=2), 0)
    return out
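
# Standalone sketch of the four feat_type branches above (the function name is
# hypothetical, not part of the source). The small eps added to log_power is an
# assumption to keep the demo finite on near-zero bins; the original applies
# torch.log directly.
def _demo_feat_types():
    B, C, T, F_bins = 2, 2, 10, 5
    xs = ComplexTensor(torch.randn(B, C, T, F_bins),
                       torch.randn(B, C, T, F_bins))

    amplitude = (xs.real ** 2 + xs.imag ** 2) ** 0.5   # (B, C, T, F)
    power = xs.real ** 2 + xs.imag ** 2                # (B, C, T, F)
    log_power = torch.log(power + 1e-10)               # (B, C, T, F)
    concat = torch.cat([xs.real, xs.imag], -1)         # (B, C, T, 2 * F)

    print(amplitude.shape, log_power.shape)  # both torch.Size([2, 2, 10, 5])
    print(concat.shape)                      # torch.Size([2, 2, 10, 10])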
def wpe_one_iteration(Y: ComplexTensor,
                      power: torch.Tensor,
                      taps: int = 10,
                      delay: int = 3,
                      eps: float = 1e-10,
                      inverse_power: bool = True) -> ComplexTensor:
    """WPE for one iteration

    Args:
        Y: Complex valued STFT signal with shape (..., C, T)
        power: Power spectrogram estimate with shape (..., T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        eps: Flooring value used to avoid division by zero when
            inverting `power`
        inverse_power (bool): If True, the frame weights are computed as
            1 / max(power, eps); otherwise `power` is used as the
            weights directly.
    Returns:
        enhanced: (..., C, T)
    """
    assert Y.size()[:-2] == power.size()[:-1]
    batch_freq_size = Y.size()[:-2]
    # Flatten the leading (batch, frequency) axes: (..., C, T) -> (N, C, T)
    Y = Y.view(-1, *Y.size()[-2:])
    power = power.view(-1, power.size()[-1])

    if inverse_power:
        weight = 1 / torch.clamp(power, min=eps)
    else:
        weight = power

    correlation_matrix, correlation_vector = \
        get_correlations(Y, weight, taps, delay)
    filter_matrix_conj = get_filter_matrix_conj(correlation_matrix,
                                                correlation_vector)
    enhanced = perform_filter_operation_v2(Y, filter_matrix_conj, taps, delay)

    # Restore the original leading axes: (N, C, T) -> (..., C, T)
    enhanced = enhanced.view(*batch_freq_size, *Y.size()[-2:])
    return enhanced
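
# Sketch of a typical iterative driver around wpe_one_iteration: the power
# estimate is refreshed from the latest enhanced signal each pass. The function
# name, the three-iteration count, and the channel-averaged power estimate are
# conventional choices assumed for illustration, not taken from the source;
# the helpers referenced inside wpe_one_iteration must be in scope.
def _demo_wpe_iterations():
    F_bins, C, T = 257, 2, 100
    Y = ComplexTensor(torch.randn(F_bins, C, T), torch.randn(F_bins, C, T))

    enhanced = Y
    for _ in range(3):
        # Power spectrogram averaged over channels: (F, C, T) -> (F, T)
        power = (enhanced.real ** 2 + enhanced.imag ** 2).mean(dim=-2)
        enhanced = wpe_one_iteration(Y, power, taps=10, delay=3)
    return enhanced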