Example #1
    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real ** 2 + h.imag ** 2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
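The pivotal step above is leaving the complex domain: |X|^2 = real^2 + imag^2 turns a ComplexTensor into a plain torch.Tensor for the log-mel stage. A minimal standalone sketch of just that step, assuming the torch_complex package that provides ComplexTensor:

import torch
from torch_complex.tensor import ComplexTensor

# Toy spectrum with shape (B, T, F) = (2, 100, 257)
x = ComplexTensor(torch.randn(2, 100, 257), torch.randn(2, 100, 257))
# |X|^2 = Re(X)^2 + Im(X)^2; the result is a plain torch.Tensor of the same shape
power = x.real ** 2 + x.imag ** 2
assert torch.is_tensor(power) and power.shape == (2, 100, 257)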
Example #2
def get_mvdr_vector(psd_s: ComplexTensor,
                    psd_n: ComplexTensor,
                    reference_vector: torch.Tensor,
                    eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR(Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n += eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
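A hedged usage sketch of get_mvdr_vector: the speech/noise PSDs below are toy estimates built from random observations, purely to exercise the documented shapes (torch_complex assumed for ComplexTensor and FC):

import torch
from torch_complex.tensor import ComplexTensor
import torch_complex.functional as FC

B, Freq, C, T = 2, 4, 3, 10
obs = ComplexTensor(torch.randn(B, Freq, C, T), torch.randn(B, Freq, C, T))
noise = ComplexTensor(torch.randn(B, Freq, C, T), torch.randn(B, Freq, C, T))
# Rough PSD estimates via outer products averaged over frames: (B, Freq, C, C)
psd_s = FC.einsum("...ct,...et->...ce", [obs, obs.conj()]) / T
psd_n = FC.einsum("...ct,...et->...ce", [noise, noise.conj()]) / T
u = torch.zeros(B, C)
u[:, 0] = 1.0  # one-hot reference vector picking channel 0
h = get_mvdr_vector(psd_s, psd_n, u)
assert h.shape == (B, Freq, C)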
Example #3
    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]
        """
        if not isinstance(input, ComplexTensor) and (
                is_torch_1_9_plus and not torch.is_complex(input)):
            raise TypeError("Only support complex tensors for stft decoder")

        bs = input.size(0)
        if input.dim() == 4:
            multi_channel = True
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            input = input.transpose(1, 2).reshape(-1, input.size(1),
                                                  input.size(3))
        else:
            multi_channel = False

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens
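The multi-channel path hinges on a flatten/unflatten round trip around the iSTFT. A shape-only sketch, with real tensors standing in for the complex spectrum and for the output of self.stft.inverse:

import torch

bs, t, c, f, nsamples = 2, 50, 4, 257, 8000
spec = torch.randn(bs, t, c, f)                  # real stand-in for the complex input
flat = spec.transpose(1, 2).reshape(-1, t, f)    # (B, T, C, F) -> (B * C, T, F)
wav = torch.randn(bs * c, nsamples)              # stand-in for self.stft.inverse output
wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)
assert flat.shape == (bs * c, t, f) and wav.shape == (bs, nsamples, c)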
Example #4
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])

        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        #       -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        return input_feats, feats_lens
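Step 1 hands over a tensor with the real and imaginary parts stacked in the last dimension, which step 2's preamble wraps in a ComplexTensor. A minimal sketch of that conversion, assuming torch_complex:

import torch
from torch_complex.tensor import ComplexTensor

stft_out = torch.randn(2, 120, 257, 2)  # (Batch, Length, Freq, 2) as asserted above
spec = ComplexTensor(stft_out[..., 0], stft_out[..., 1])
assert spec.shape == stft_out.shape[:-1]  # (2, 120, 257)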
Example #5
    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray,
                                             List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                choices = [] if self.use_frontend_for_all else [(False, False)]
                if self.use_wpe:
                    choices.append((True, False))

                if self.use_beamformer:
                    choices.append((False, True))

                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]

            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
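During training, the snippet samples one processing configuration per forward pass. The same choice logic, decoupled from the module, with stand-in flags for the self.* attributes:

import numpy

# Stand-ins for self.use_frontend_for_all / self.use_wpe / self.use_beamformer
use_frontend_for_all, has_wpe, has_beamformer = False, True, True

choices = [] if use_frontend_for_all else [(False, False)]
if has_wpe:
    choices.append((True, False))
if has_beamformer:
    choices.append((False, True))
# Uniformly pick one (use_wpe, use_beamformer) configuration
use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]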
Example #6
def tik_reg(mat: ComplexTensor,
            reg: float = 1e-8,
            eps: float = 1e-8) -> ComplexTensor:
    """Perform Tikhonov regularization (only modifying real part).

    Args:
        mat (ComplexTensor): input matrix (..., C, C)
        reg (float): regularization factor
        eps (float)
    Returns:
        ret (ComplexTensor): regularized matrix (..., C, C)
    """
    # Add eps
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    shape = [1 for _ in range(mat.dim() - 2)] + [C, C]
    eye = eye.view(*shape).repeat(*mat.shape[:-2], 1, 1)
    with torch.no_grad():
        epsilon = FC.trace(mat).real[..., None, None] * reg
        # in case that correlation_matrix is all-zero
        epsilon = epsilon + eps
    mat = mat + epsilon * eye
    return mat
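A hedged usage sketch: a rank-1 (hence singular) Hermitian matrix becomes safely invertible after tik_reg (torch_complex assumed for ComplexTensor and FC):

import torch
from torch_complex.tensor import ComplexTensor
import torch_complex.functional as FC

v = ComplexTensor(torch.randn(5, 3, 1), torch.randn(5, 3, 1))
mat = FC.matmul(v, v.conj().transpose(-1, -2))  # rank-1, hence singular (5, 3, 3)
inv = tik_reg(mat, reg=1e-6).inverse()          # well-defined after regularization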
Example #7
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps:

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    correlation_matrix += eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
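The heart of the function is the permute/view bookkeeping between the (F, taps, C, C) and (F, C, taps * C) layouts. A shape-only sketch with a real tensor as a stand-in for the complex statistics:

import torch

F_, taps, C = 257, 5, 4
vec = torch.randn(F_, taps, C, C)                     # real stand-in
flat = vec.permute(0, 2, 1, 3).contiguous().view(F_, C, taps * C)
filt = flat.view(F_, C, taps, C).permute(0, 2, 3, 1)  # back out to (F, taps, C2, C1)
assert flat.shape == (F_, C, taps * C) and filt.shape == (F_, taps, C, C)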
Example #8
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0], input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        enhanced = input_spectrum
        masks = OrderedDict()

        if input_spectrum.dim() == 3:
            # single-channel input
            if self.use_wpe:
                # (B, T, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum.unsqueeze(-2), flens)
                enhanced = enhanced.squeeze(-2)
                if mask_w is not None:
                    masks["dereverb"] = mask_w.squeeze(-2)

        elif input_spectrum.dim() == 4:
            # multi-channel input
            # 1. WPE
            if self.use_wpe:
                # (B, T, C, F)
                enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                if mask_w is not None:
                    masks["dereverb"] = mask_w

            # 2. Beamformer
            if self.use_beamformer:
                # enhanced: (B, T, C, F) -> (B, T, F)
                enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

        else:
            raise ValueError(
                "Invalid spectrum dimension: {}".format(input_spectrum.shape)
            )

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag], dim=-1).float()
        return enhanced, flens, masks
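The closing stack mirrors the ComplexTensor construction at the top of forward, so the two conversions form a round trip. A minimal sketch, assuming torch_complex:

import torch
from torch_complex.tensor import ComplexTensor

enh = ComplexTensor(torch.randn(2, 100, 257), torch.randn(2, 100, 257))
packed = torch.stack([enh.real, enh.imag], dim=-1).float()  # (B, T, F, 2)
restored = ComplexTensor(packed[..., 0], packed[..., 1])    # round trip back
assert packed.shape == (2, 100, 257, 2)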
Example #9
    def forward(
        self, input: ComplexTensor, ilens: torch.Tensor
    ) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel): List[ComplexTensor]
            output lengths
            other predicted data: OrderedDict[
                'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq),
                'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input.dim() in (3, 4), input.dim()
        enhanced = input
        others = OrderedDict()

        if (self.training and self.loss_type is not None
                and self.loss_type.startswith("mask")):
            # Only estimating masks during training for saving memory
            if self.use_wpe:
                if input.dim() == 3:
                    mask_w, ilens = self.wpe.predict_mask(
                        input.unsqueeze(-2), ilens)
                    mask_w = mask_w.squeeze(-2)
                elif input.dim() == 4:
                    mask_w, ilens = self.wpe.predict_mask(input, ilens)

                if mask_w is not None:
                    if isinstance(mask_w, list):
                        # single-source WPE
                        for spk in range(self.num_spk):
                            others["mask_dereverb{}".format(spk + 1)] = mask_w[spk]
                    else:
                        # multi-source WPE
                        others["mask_dereverb1"] = mask_w

            if self.use_beamformer and input.dim() == 4:
                others_b, ilens = self.beamformer.predict_mask(input, ilens)
                for spk in range(self.num_spk):
                    others["mask_spk{}".format(spk + 1)] = others_b[spk]
                if len(others_b) > self.num_spk:
                    others["mask_noise1"] = others_b[self.num_spk]

            return None, ilens, others

        else:
            powers = None
            # Performing both mask estimation and enhancement
            if input.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(
                        input.unsqueeze(-2), ilens)
                    if isinstance(enhanced, list):
                        # single-source WPE
                        enhanced = [enh.squeeze(-2) for enh in enhanced]
                        if mask_w is not None:
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk].squeeze(-2)
                    else:
                        # multi-source WPE
                        enhanced = enhanced.squeeze(-2)
                        if mask_w is not None:
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, ilens, mask_w, powers = self.wpe(input, ilens)
                    if mask_w is not None:
                        if isinstance(enhanced, list):
                            # single-source WPE
                            for spk in range(self.num_spk):
                                key = "dereverb{}".format(spk + 1)
                                others[key] = enhanced[spk]
                                others["mask_" + key] = mask_w[spk]
                        else:
                            # multi-source WPE
                            others["dereverb1"] = enhanced
                            others["mask_dereverb1"] = mask_w.squeeze(-2)

                # 2. Beamformer
                if self.use_beamformer:
                    if (not self.beamformer.beamformer_type.startswith("wmpdr")
                            or not self.beamformer.beamformer_type.startswith(
                                "wpd") or not self.shared_power
                            or (self.wpe.nmask == 1 and self.num_spk > 1)):
                        powers = None

                    # enhanced: (B, T, C, F) -> (B, T, F)
                    if isinstance(enhanced, list):
                        # outputs of single-source WPE
                        raise NotImplementedError(
                            "Single-source WPE is not supported with beamformer "
                            "in multi-speaker cases.")
                    else:
                        # output of multi-source WPE
                        enhanced, ilens, others_b = self.beamformer(
                            enhanced, ilens, powers=powers)
                    for spk in range(self.num_spk):
                        others["mask_spk{}".format(spk + 1)] = others_b[spk]
                    if len(others_b) > self.num_spk:
                        others["mask_noise1"] = others_b[self.num_spk]

        if not isinstance(enhanced, list):
            enhanced = [enhanced]

        return enhanced, ilens, others
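The guard before the beamformer call resets powers unless they are actually reusable. The same condition, restated as a standalone predicate; can_share_powers is a hypothetical helper for illustration, not part of ESPnet:

def can_share_powers(beamformer_type: str, shared_power: bool,
                     wpe_nmask: int, num_spk: int) -> bool:
    # WPE's power estimates are only reusable by power-dependent beamformers
    # (wMPDR / WPD), with sharing enabled and a compatible mask/speaker setup.
    is_power_based = beamformer_type.startswith(("wmpdr", "wpd"))
    return is_power_based and shared_power and not (wpe_nmask == 1 and num_spk > 1)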
Example #10
    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, Nsample, Channel]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            enhanced speech (single-channel):
                torch.Tensor or List[torch.Tensor]
            output lengths
            predicted masks: OrderedDict[
                'dereverb': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk1': torch.Tensor(Batch, Frames, Channel, Freq),
                'spk2': torch.Tensor(Batch, Frames, Channel, Freq),
                ...
                'spkn': torch.Tensor(Batch, Frames, Channel, Freq),
                'noise1': torch.Tensor(Batch, Frames, Channel, Freq),
            ]
        """
        # wave -> STFT -> complex spectrum
        input_spectrum, flens = self.stft(input, ilens)
        # (Batch, Frames, Freq) or (Batch, Frames, Channels, Freq)
        input_spectrum = ComplexTensor(input_spectrum[..., 0],
                                       input_spectrum[..., 1])
        if self.normalize_input:
            input_spectrum = input_spectrum / abs(input_spectrum).max()

        # Shape of input spectrum must be (B, T, F) or (B, T, C, F)
        assert input_spectrum.dim() in (3, 4), input_spectrum.dim()
        enhanced = input_spectrum
        masks = OrderedDict()

        if self.training and self.loss_type.startswith("mask"):
            # Only estimating masks for training
            if self.use_wpe:
                if input_spectrum.dim() == 3:
                    mask_w, flens = self.wpe.predict_mask(
                        input_spectrum.unsqueeze(-2), flens)
                    mask_w = mask_w.squeeze(-2)
                elif input_spectrum.dim() == 4:
                    if self.use_beamformer:
                        enhanced, flens, mask_w = self.wpe(
                            input_spectrum, flens)
                    else:
                        mask_w, flens = self.wpe.predict_mask(
                            input_spectrum, flens)

                if mask_w is not None:
                    masks["dereverb"] = mask_w

            if self.use_beamformer and input_spectrum.dim() == 4:
                masks_b, flens = self.beamformer.predict_mask(enhanced, flens)
                for spk in range(self.num_spk):
                    masks["spk{}".format(spk + 1)] = masks_b[spk]
                if len(masks_b) > self.num_spk:
                    masks["noise1"] = masks_b[self.num_spk]

            return None, flens, masks

        else:
            # Performing both mask estimation and enhancement
            if input_spectrum.dim() == 3:
                # single-channel input (B, T, F)
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(
                        input_spectrum.unsqueeze(-2), flens)
                    enhanced = enhanced.squeeze(-2)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w.squeeze(-2)
            else:
                # multi-channel input (B, T, C, F)
                # 1. WPE
                if self.use_wpe:
                    enhanced, flens, mask_w = self.wpe(input_spectrum, flens)
                    if mask_w is not None:
                        masks["dereverb"] = mask_w

                # 2. Beamformer
                if self.use_beamformer:
                    # enhanced: (B, T, C, F) -> (B, T, F)
                    enhanced, flens, masks_b = self.beamformer(enhanced, flens)
                    for spk in range(self.num_spk):
                        masks["spk{}".format(spk + 1)] = masks_b[spk]
                    if len(masks_b) > self.num_spk:
                        masks["noise1"] = masks_b[self.num_spk]

        # Convert ComplexTensor to torch.Tensor
        # (B, T, F) -> (B, T, F, 2)
        if isinstance(enhanced, list):
            # multi-speaker output
            enhanced = [
                torch.stack([enh.real, enh.imag], dim=-1) for enh in enhanced
            ]
        else:
            # single-speaker output
            enhanced = torch.stack([enhanced.real, enhanced.imag],
                                   dim=-1).float()
        return enhanced, flens, masks