def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Run the DC-CRN separator on a complex-valued spectrum.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [Batch, T, F]
            or [Batch, T, C, F]. Must be complex (checked with an assertion).
        ilens (torch.Tensor): input lengths [Batch,]; returned untouched.
        additional (Dict or None): not read by this method.

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(Batch, T, F), ...]
            In "masking" mode these are ``input`` multiplied by each predicted
            mask; in any other mode they are the network outputs themselves.
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
            When ``self.predict_noise`` is set, a 'noise1' entry holding the
            noise estimate is added as well.
    """
    assert is_complex(input)
    multi_ch = input.ndim == 4

    def _per_channel(tensor):
        # A 4-D input gets an extra axis at dim 2 so that the estimate
        # broadcasts against it; a 3-D input is used as is.
        return tensor.unsqueeze(2) if multi_ch else tensor

    # Real and imaginary parts are fed to the network as separate planes.
    real_imag = [input.real, input.imag]
    if multi_ch:
        net_in = torch.cat(real_imag, dim=2).permute(0, 2, 1, 3)
    else:
        net_in = torch.stack(real_imag, dim=1)

    net_out = self.dc_crn(net_in)
    # dim 2 of the network output enumerates the estimates,
    # dim 1 holds their (real, imag) pair.
    estimates = [
        new_complex_like(input, est.unbind(dim=1)) for est in net_out.unbind(dim=2)
    ]

    if self.predict_noise:
        # The last estimate belongs to the noise, the rest to the speakers.
        *estimates, noise_est = estimates

    if self.mode == "masking":
        masks = estimates
        masked = [input * _per_channel(m) for m in masks]
    else:
        # The network predicted the signals directly; derive the
        # equivalent masks from them for reporting.
        masked = estimates
        masks = [_per_channel(m) / (input + EPS) for m in masked]

    others = OrderedDict()
    for spk, mask in enumerate(masks, start=1):
        others["mask_spk{}".format(spk)] = mask

    if self.predict_noise:
        noise_est = _per_channel(noise_est)
        if self.mode == "masking":
            others["noise1"] = input * noise_est
        else:
            others["noise1"] = noise_est

    return masked, ilens, others
def _create_mask_label(mix_spec, ref_spec, noise_spec=None, mask_type="IAM"): """Create mask label. Args: mix_spec: ComplexTensor(B, T, [C,] F) ref_spec: List[ComplexTensor(B, T, [C,] F), ...] noise_spec: ComplexTensor(B, T, [C,] F) only used for IBM and IRM mask_type: str Returns: labels: List[Tensor(B, T, [C,] F), ...] or List[ComplexTensor(B, T, F), ...] """ # Must be upper case mask_type = mask_type.upper() assert mask_type in [ "IBM", "IRM", "IAM", "PSM", "NPSM", "PSM^2", "CIRM", ], f"mask type {mask_type} not supported" mask_label = [] if ref_spec[0].ndim < mix_spec.ndim: # (B, T, F) -> (B, T, 1, F) ref_spec = [r.unsqueeze(2).expand_as(mix_spec.real) for r in ref_spec] if noise_spec is not None and noise_spec.ndim < mix_spec.ndim: # (B, T, F) -> (B, T, 1, F) noise_spec = noise_spec.unsqueeze(2).expand_as(mix_spec.real) for idx, r in enumerate(ref_spec): mask = None if mask_type == "IBM": if noise_spec is None: flags = [abs(r) >= abs(n) for n in ref_spec] else: flags = [abs(r) >= abs(n) for n in ref_spec + [noise_spec]] mask = reduce(lambda x, y: x * y, flags) mask = mask.int() elif mask_type == "IRM": beta = 0.5 res_spec = sum(n for i, n in enumerate(ref_spec) if i != idx) if noise_spec is not None: res_spec += noise_spec mask = (abs(r).pow(2) / (abs(res_spec).pow(2) + EPS)).pow(beta) elif mask_type == "IAM": mask = abs(r) / (abs(mix_spec) + EPS) mask = mask.clamp(min=0, max=1) elif mask_type == "PSM" or mask_type == "NPSM": phase_r = r / (abs(r) + EPS) phase_mix = mix_spec / (abs(mix_spec) + EPS) # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b) cos_theta = phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag mask = (abs(r) / (abs(mix_spec) + EPS)) * cos_theta mask = ( mask.clamp(min=0, max=1) if mask_type == "NPSM" else mask.clamp(min=-1, max=1) ) elif mask_type == "PSM^2": # This is for training beamforming masks phase_r = r / (abs(r) + EPS) phase_mix = mix_spec / (abs(mix_spec) + EPS) # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b) cos_theta = 
phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag mask = (abs(r).pow(2) / (abs(mix_spec).pow(2) + EPS)) * cos_theta mask = mask.clamp(min=-1, max=1) elif mask_type == "CIRM": # Ref: Complex Ratio Masking for Monaural Speech Separation denominator = mix_spec.real.pow(2) + mix_spec.imag.pow(2) + EPS mask_real = (mix_spec.real * r.real + mix_spec.imag * r.imag) / denominator mask_imag = (mix_spec.real * r.imag - mix_spec.imag * r.real) / denominator mask = new_complex_like(mix_spec, [mask_real, mask_imag]) assert mask is not None, f"mask type {mask_type} not supported" mask_label.append(mask) return mask_label
def _create_mask_label(mix_spec, ref_spec, mask_type="IAM"): """Create mask label. Args: mix_spec: ComplexTensor(B, T, [C,] F) ref_spec: List[ComplexTensor(B, T, [C,] F), ...] mask_type: str Returns: labels: List[Tensor(B, T, [C,] F), ...] or List[ComplexTensor(B, T, F), ...] """ # Must be upper case mask_type = mask_type.upper() assert mask_type in [ "IBM", "IRM", "IAM", "PSM", "NPSM", "PSM^2", "CIRM", ], f"mask type {mask_type} not supported" mask_label = [] for r in ref_spec: mask = None if mask_type == "IBM": flags = [abs(r) >= abs(n) for n in ref_spec] mask = reduce(lambda x, y: x * y, flags) mask = mask.int() elif mask_type == "IRM": # TODO(Wangyou): need to fix this, # as noise referecens are provided separately mask = abs(r) / (sum(([abs(n) for n in ref_spec])) + EPS) elif mask_type == "IAM": mask = abs(r) / (abs(mix_spec) + EPS) mask = mask.clamp(min=0, max=1) elif mask_type == "PSM" or mask_type == "NPSM": phase_r = r / (abs(r) + EPS) phase_mix = mix_spec / (abs(mix_spec) + EPS) # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b) cos_theta = phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag mask = (abs(r) / (abs(mix_spec) + EPS)) * cos_theta mask = ( mask.clamp(min=0, max=1) if mask_type == "NPSM" else mask.clamp(min=-1, max=1) ) elif mask_type == "PSM^2": # This is for training beamforming masks phase_r = r / (abs(r) + EPS) phase_mix = mix_spec / (abs(mix_spec) + EPS) # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b) cos_theta = phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag mask = (abs(r).pow(2) / (abs(mix_spec).pow(2) + EPS)) * cos_theta mask = mask.clamp(min=-1, max=1) elif mask_type == "CIRM": # Ref: Complex Ratio Masking for Monaural Speech Separation denominator = mix_spec.real.pow(2) + mix_spec.imag.pow(2) + EPS mask_real = (mix_spec.real * r.real + mix_spec.imag * r.imag) / denominator mask_imag = (mix_spec.real * r.imag - mix_spec.imag * r.real) / denominator mask = new_complex_like(mix_spec, [mask_real, mask_imag]) assert mask is 
not None, f"mask type {mask_type} not supported" mask_label.append(mask) return mask_label