def predict_mask(
    self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """Estimate the masks used for beamforming.

    Args:
        data (torch.complex64/ComplexTensor): input of shape (B, T, C, F),
            double precision
        ilens (torch.Tensor): input lengths of shape (B,)

    Returns:
        masks (List[torch.Tensor]): one entry per output of ``self.mask``,
            each laid out as (B, T, C, F)
        ilens (torch.Tensor): the ``ilens`` argument, returned unchanged
    """
    # The estimator is fed (B, F, C, T), so the T and F axes are swapped
    # before the input is passed through ``to_float``.
    estimator_input = to_float(data.permute(0, 3, 2, 1))

    # The second value returned by the estimator is deliberately ignored;
    # the caller's ``ilens`` is what gets handed back.
    estimated, _ = self.mask(estimator_input, ilens)

    # Undo the axis swap on every mask: (B, F, C, T) -> (B, T, C, F)
    restored = []
    for mask in estimated:
        restored.append(mask.transpose(-1, -3))
    return restored, ilens
def predict_mask(
    self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Estimate the mask(s) used for WPE dereverberation.

    Args:
        data (torch.complex64/ComplexTensor): input of shape (B, T, C, F),
            double precision
        ilens (torch.Tensor): input lengths of shape (B,)

    Returns:
        masks (torch.Tensor or List[torch.Tensor] or None): ``None`` when
            ``self.use_dnn_mask`` is falsy; a single (B, T, C, F) tensor
            when ``self.nmask == 1``; otherwise the list of (B, T, C, F)
            tensors
        ilens (torch.Tensor): (B,) lengths as returned by ``self.mask_est``,
            or the ``ilens`` argument unchanged when no mask is estimated
    """
    # Without a DNN mask estimator there is nothing to predict.
    if not self.use_dnn_mask:
        return None, ilens

    # The estimator is fed (B, F, C, T), so the T and F axes are swapped
    # before the input is passed through ``to_float``.
    estimator_input = to_float(data.permute(0, 3, 2, 1))
    estimated, ilens = self.mask_est(estimator_input, ilens)

    # Undo the axis swap on every mask: (B, F, C, T) -> (B, T, C, F)
    restored = [mask.transpose(-1, -3) for mask in estimated]

    # A single-mask configuration hands back the bare tensor, not a list.
    if self.nmask == 1:
        return restored[0], ilens
    return restored, ilens