Exemplo n.º 1
0
    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Separate a complex spectrum with the DC-CRN network.

        The real and imaginary parts of ``input`` are fed to ``self.dc_crn``
        as channels. Depending on ``self.mode`` the network outputs are used
        either as masks that multiply ``input`` ("masking") or directly as
        the separated spectra (any other mode), in which case the reported
        masks are derived as ``output / (input + EPS)``.

        Args:
            input (torch.Tensor or ComplexTensor): complex encoded feature,
                [Batch, T, F] or [Batch, T, C, F]; must pass ``is_complex``
            ilens (torch.Tensor): input lengths [Batch,], returned unchanged
            additional (Dict or None): not used by this method

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(Batch, T, F), ...]
                one entry per source; in "masking" mode with 4-dim input the
                mask gets a singleton axis at dim 2 before multiplying
                ``input``, so the product follows the input's shape
            ilens (torch.Tensor): (B,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
                'noise1': noise estimate, only when ``self.predict_noise``
            ]
        """
        assert is_complex(input)
        multi_channel = input.ndim == 4

        # Real and imaginary parts become the channel axis (dim 1) of the
        # network input.
        if not multi_channel:
            feature = torch.stack([input.real, input.imag], dim=1)
        else:
            real_imag = torch.cat([input.real, input.imag], dim=2)
            feature = real_imag.permute(0, 2, 1, 3)

        def _match_channels(spec):
            # Insert a singleton axis at dim 2 so a per-source estimate
            # broadcasts against 4-dim input.
            return spec.unsqueeze(2) if multi_channel else spec

        net_out = self.dc_crn(feature)
        # dim 2 of the network output enumerates the estimates; dim 1 holds
        # the (real, imag) pair of each one.
        masks = []
        for per_output in net_out.unbind(dim=2):
            masks.append(new_complex_like(input, per_output.unbind(dim=1)))

        if self.predict_noise:
            # The last estimate belongs to the noise, the rest to the sources.
            *masks, mask_noise = masks

        if self.mode == "masking":
            masked = [input * _match_channels(m) for m in masks]
        else:
            # The network outputs are the spectra themselves; recover the
            # equivalent masks for reporting.
            masked = masks
            masks = [_match_channels(m) / (input + EPS) for m in masked]

        others = OrderedDict()
        for spk, m in enumerate(masks, start=1):
            others["mask_spk{}".format(spk)] = m

        if self.predict_noise:
            noise = _match_channels(mask_noise)
            others["noise1"] = input * noise if self.mode == "masking" else noise

        return masked, ilens, others
Exemplo n.º 2
0
def _cos_phase_difference(ref, mix):
    """Return the cosine of the phase difference between two complex spectra.

    Each spectrum is divided by ``abs(spectrum) + EPS``; the cosine then
    follows from cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b).
    """
    unit_ref = ref / (abs(ref) + EPS)
    unit_mix = mix / (abs(mix) + EPS)
    return unit_ref.real * unit_mix.real + unit_ref.imag * unit_mix.imag


def _create_mask_label(mix_spec, ref_spec, noise_spec=None, mask_type="IAM"):
    """Create mask label.

    Args:
        mix_spec: ComplexTensor(B, T, [C,] F)
        ref_spec: sequence (list or tuple) of ComplexTensor(B, T, [C,] F);
            entries with fewer dims than ``mix_spec`` get an axis inserted at
            dim 2 and are expanded to the shape of ``mix_spec``
        noise_spec: ComplexTensor(B, T, [C,] F)
            only used for IBM and IRM
        mask_type: str, case-insensitive; one of
            IBM, IRM, IAM, PSM, NPSM, PSM^2, CIRM
    Returns:
        labels: List[Tensor(B, T, [C,] F), ...] or List[ComplexTensor(B, T, F), ...]
            one label per entry of ``ref_spec`` (empty list for empty input)
    Raises:
        AssertionError: if ``mask_type`` is not supported
    """

    # Must be upper case
    mask_type = mask_type.upper()
    assert mask_type in [
        "IBM",
        "IRM",
        "IAM",
        "PSM",
        "NPSM",
        "PSM^2",
        "CIRM",
    ], f"mask type {mask_type} not supported"

    # Work on a list copy so that tuples are accepted as well (the IBM branch
    # needs list concatenation) and the caller's sequence is never touched.
    ref_spec = list(ref_spec)
    if ref_spec and ref_spec[0].ndim < mix_spec.ndim:
        # (B, T, F) -> (B, T, 1, F)
        ref_spec = [r.unsqueeze(2).expand_as(mix_spec.real) for r in ref_spec]
    if noise_spec is not None and noise_spec.ndim < mix_spec.ndim:
        # (B, T, F) -> (B, T, 1, F)
        noise_spec = noise_spec.unsqueeze(2).expand_as(mix_spec.real)

    # Every signal a reference is compared against in the IBM
    # (the reference itself included, plus the noise when given).
    competitors = ref_spec if noise_spec is None else ref_spec + [noise_spec]

    mask_label = []
    for idx, r in enumerate(ref_spec):
        mask = None
        if mask_type == "IBM":
            ref_mag = abs(r)
            flags = [ref_mag >= abs(n) for n in competitors]
            # 1 only where this reference is at least as strong as all others
            mask = reduce(lambda x, y: x * y, flags).int()
        elif mask_type == "IRM":
            beta = 0.5
            interferers = [n for i, n in enumerate(ref_spec) if i != idx]
            if noise_spec is not None:
                interferers.append(noise_spec)
            if interferers:
                res_power = abs(sum(interferers)).pow(2)
            else:
                # Single reference without noise: nothing interferes. The
                # builtin sum() of nothing is the int 0, which has no .pow(),
                # so the zero residual power is spelled out explicitly.
                res_power = 0
            # NOTE(review): the denominator holds only the residual power, not
            # target + residual, so this ratio is not bounded by 1 -- confirm
            # that this is the intended IRM variant.
            mask = (abs(r).pow(2) / (res_power + EPS)).pow(beta)
        elif mask_type == "IAM":
            mask = abs(r) / (abs(mix_spec) + EPS)
            mask = mask.clamp(min=0, max=1)
        elif mask_type == "PSM" or mask_type == "NPSM":
            cos_theta = _cos_phase_difference(r, mix_spec)
            mask = (abs(r) / (abs(mix_spec) + EPS)) * cos_theta
            # NPSM is the non-negative variant of PSM
            lower = 0 if mask_type == "NPSM" else -1
            mask = mask.clamp(min=lower, max=1)
        elif mask_type == "PSM^2":
            # This is for training beamforming masks
            cos_theta = _cos_phase_difference(r, mix_spec)
            mask = (abs(r).pow(2) / (abs(mix_spec).pow(2) + EPS)) * cos_theta
            mask = mask.clamp(min=-1, max=1)
        elif mask_type == "CIRM":
            # Ref: Complex Ratio Masking for Monaural Speech Separation
            denominator = mix_spec.real.pow(2) + mix_spec.imag.pow(2) + EPS
            mask_real = (mix_spec.real * r.real + mix_spec.imag * r.imag) / denominator
            mask_imag = (mix_spec.real * r.imag - mix_spec.imag * r.real) / denominator
            mask = new_complex_like(mix_spec, [mask_real, mask_imag])
        assert mask is not None, f"mask type {mask_type} not supported"
        mask_label.append(mask)
    return mask_label
Exemplo n.º 3
0
def _create_mask_label(mix_spec, ref_spec, mask_type="IAM"):
    """Build one training-target mask per reference spectrum.

    Args:
        mix_spec: ComplexTensor(B, T, [C,] F), spectrum of the mixture
        ref_spec: List[ComplexTensor(B, T, [C,] F), ...], one entry per source
        mask_type: str, matched case-insensitively; one of
            IBM, IRM, IAM, PSM, NPSM, PSM^2, CIRM
    Returns:
        labels: List[Tensor(B, T, [C,] F), ...] or List[ComplexTensor(B, T, F), ...]
    """

    # Mask names are compared in upper case
    mask_type = mask_type.upper()
    supported = ["IBM", "IRM", "IAM", "PSM", "NPSM", "PSM^2", "CIRM"]
    assert mask_type in supported, f"mask type {mask_type} not supported"

    labels = []
    for ref in ref_spec:
        label = None
        if mask_type == "CIRM":
            # Ref: Complex Ratio Masking for Monaural Speech Separation
            mix_re, mix_im = mix_spec.real, mix_spec.imag
            denominator = mix_re.pow(2) + mix_im.pow(2) + EPS
            label_re = (mix_re * ref.real + mix_im * ref.imag) / denominator
            label_im = (mix_re * ref.imag - mix_im * ref.real) / denominator
            label = new_complex_like(mix_spec, [label_re, label_im])
        elif mask_type in ("PSM", "NPSM", "PSM^2"):
            ref_mag = abs(ref)
            mix_mag = abs(mix_spec)
            unit_ref = ref / (ref_mag + EPS)
            unit_mix = mix_spec / (mix_mag + EPS)
            # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b)
            cos_theta = unit_ref.real * unit_mix.real + unit_ref.imag * unit_mix.imag
            if mask_type == "PSM^2":
                # This is for training beamforming masks
                ratio = ref_mag.pow(2) / (mix_mag.pow(2) + EPS)
            else:
                ratio = ref_mag / (mix_mag + EPS)
            # Only NPSM is restricted to non-negative values
            lower = 0 if mask_type == "NPSM" else -1
            label = (ratio * cos_theta).clamp(min=lower, max=1)
        elif mask_type == "IAM":
            label = (abs(ref) / (abs(mix_spec) + EPS)).clamp(min=0, max=1)
        elif mask_type == "IRM":
            # TODO(Wangyou): need to fix this,
            #  as noise references are provided separately
            total_mag = sum(abs(other) for other in ref_spec)
            label = abs(ref) / (total_mag + EPS)
        elif mask_type == "IBM":
            # Multiply the comparison flags together: the result is 1 only
            # where this reference is at least as strong as every reference.
            dominant = None
            for other in ref_spec:
                flag = abs(ref) >= abs(other)
                dominant = flag if dominant is None else dominant * flag
            label = dominant.int()
        assert label is not None, f"mask type {mask_type} not supported"
        labels.append(label)
    return labels