Example #1
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
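A minimal smoke test for this helper, assuming the torch_complex package that supplies ComplexTensor and its functional module FC (as used above); the correlation matrix is made Hermitian positive definite so that inverse() is well conditioned:

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

F_, taps, C = 8, 3, 2  # toy sizes

# Build a Hermitian positive-definite correlation matrix: A A^H + eps I
a = ComplexTensor(torch.randn(F_, taps * C, taps * C),
                  torch.randn(F_, taps * C, taps * C))
corr_mat = FC.matmul(a, a.conj().transpose(-1, -2)) + 1e-4 * torch.eye(taps * C)
corr_vec = ComplexTensor(torch.randn(F_, taps, C, C),
                         torch.randn(F_, taps, C, C))

filt = get_filter_matrix_conj(corr_mat, corr_vec)
assert filt.size() == (F_, taps, C, C)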
Example #2
    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            masks: A tuple of the masks. (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2)**0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            if self.nonlinear == "sigmoid":
                mask = torch.sigmoid(mask)
            elif self.nonlinear == "relu":
                mask = torch.relu(mask)
            elif self.nonlinear == "tanh":
                mask = torch.tanh(mask)
            elif self.nonlinear == "crelu":
                mask = torch.clamp(mask, min=0, max=1)
            # Zero padding
            mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Takes care of multi-GPU cases: if input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
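The channel-folding trick above (merging C into the batch so one BRNN runs over every channel) can be checked in isolation; a minimal sketch with arbitrary sizes:

import torch

B, C, T, F_ = 2, 4, 50, 257
xs = torch.randn(B, C, T, F_)  # amplitude features, one sequence per channel
ilens = torch.tensor([50, 42])

# (B, C, T, F) -> (B * C, T, F): every channel becomes its own sequence
xs_flat = xs.contiguous().view(-1, T, F_)
# (B,) -> (B * C,): repeat each utterance length once per channel
ilens_flat = ilens[:, None].expand(-1, C).contiguous().view(-1)
assert xs_flat.size(0) == ilens_flat.size(0) == B * C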
Example #3
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F), double precision
            ilens: (B,)
        Returns:
            enhanced: (B, T, C, F), double precision
            ilens: (B,)
            mask: (B, T, C, F)
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                (mask, ), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = wpe_one_iteration(
                data.contiguous().double(),
                power.double(),
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )
            enhanced = enhanced.type(data.dtype)
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
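A rough driving sketch for a single WPE pass, assuming the pytorch_wpe package these frontends build on (importing wpe_one_iteration with the keyword arguments used above is an assumption about that package):

import torch
from torch_complex.tensor import ComplexTensor
from pytorch_wpe import wpe_one_iteration  # assumed dependency

B, F_, C, T = 2, 9, 4, 60
data = ComplexTensor(torch.randn(B, F_, C, T),
                     torch.randn(B, F_, C, T)).double()
# Power averaged over channels: (B, F, C, T) -> (B, F, T)
power = (data.real ** 2 + data.imag ** 2).mean(dim=-2)

enhanced = wpe_one_iteration(data.contiguous(), power,
                             taps=5, delay=3, inverse_power=True)
assert enhanced.size() == data.size()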
Example #4
    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """Predict masks for beamforming

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            masks (torch.Tensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        """
        masks, _ = self.mask(data.permute(0, 3, 2, 1).float(), ilens)
        # (B, F, C, T) -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return masks, ilens
Example #5
    def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
            -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            mask_speech (torch.Tensor): (B, T, C, F)
        """
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        (mask_speech, mask_noise), _ = self.mask(data, ilens)

        psd_speech = get_power_spectral_density_matrix(data, mask_speech)
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)

        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech, ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        ws = get_mvdr_vector(psd_speech, psd_noise, u)
        enhanced = apply_beamforming_vector(ws, data)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
        mask_speech = mask_speech.transpose(-1, -3)

        return enhanced, ilens, mask_speech
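The one-hot reference vector built in the else branch generalizes over any leading batch dims; standalone:

import torch

B, F_, C, T = 2, 9, 4, 60
size = (B, F_, C, T)  # shape of data in the (B, F, C, T) layout
ref_channel = 1

# u: (B, C), one-hot on the chosen reference microphone
u = torch.zeros(*(size[:-3] + (size[-2],)))
u[..., ref_channel].fill_(1)
assert u.shape == (B, C) and u[0].argmax().item() == ref_channel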
Example #6
    def predict_mask(
            self, data: ComplexTensor,
            ilens: torch.LongTensor) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Predict mask for WPE dereverberation

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            mask (torch.Tensor or None): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        """
        if self.use_dnn_mask:
            (mask, ), ilens = self.mask_est(
                data.permute(0, 3, 2, 1).float(), ilens)
            # (B, F, C, T) -> (B, T, C, F)
            mask = mask.transpose(-1, -3)
        else:
            mask = None
        return mask, ilens
Example #7
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps: Diagonal loading factor for numerical stability

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    correlation_matrix += eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
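The diagonal-loading step reshapes the identity so it broadcasts across the leading dims; in isolation (written out-of-place here, so the caller's matrix is not mutated the way the += above mutates it):

import torch

corr = torch.randn(7, 6, 6)  # stand-in for the (F, taps * C, taps * C) matrix
eps = 1e-10
eye = torch.eye(corr.size(-1), dtype=corr.dtype, device=corr.device)
# (6, 6) -> (1, 6, 6) so it broadcasts over the F axis
eye = eye.view((1,) * (corr.dim() - 2) + tuple(corr.shape[-2:]))
loaded = corr + eps * eye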
Example #8
    def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
            -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F)
            ilens (torch.Tensor): (B,)
            mask_speech (torch.Tensor or List[torch.Tensor]): (B, T, C, F)

        """
        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech,
                                             psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask)
                for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(data, ilens, psd_speech,
                                           sum(psd_speeches) + psd_noise)
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech
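The pop/insert idiom in the multi-speaker loop (treating every other speaker's PSD as part of the noise) reduces to this list manipulation:

psds = ["psd_spk0", "psd_spk1", "psd_spk2"]  # placeholders for speaker PSDs
for i in range(len(psds)):
    target = psds.pop(i)    # take speaker i out
    others = list(psds)     # what remains acts as the "noise" for speaker i
    psds.insert(i, target)  # restore the list for the next iteration
    assert target == f"psd_spk{i}" and len(others) == len(psds) - 1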
Example #9
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F), double precision
            ilens (torch.Tensor): (B,)
            masks (List[torch.Tensor]): (B, T, C, F)
        """
        def apply_beamforming(data, ilens, psd_speech, psd_n, beamformer_type):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.float(), ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                device=data.device)
                u[..., self.ref_channel].fill_(1)

            if beamformer_type in ("mpdr", "mvdr"):
                ws = get_mvdr_vector(psd_speech, psd_n, u.double())
                enhanced = apply_beamforming_vector(ws, data)
            elif beamformer_type == "wpd":
                ws = get_WPD_filter_v2(psd_speech, psd_n, u.double())
                enhanced = perform_WPD_filtering(ws, data, self.bdelay,
                                                 self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    beamformer_type))

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data.float(), ilens)
        assert self.nmask == len(masks)
        # floor masks with self.eps to increase numerical stability
        masks = [torch.clamp(m, min=self.eps) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            psd_speech = get_power_spectral_density_matrix(
                data, mask_speech.double())
            if self.beamformer_type == "mvdr":
                # psd of noise
                psd_n = get_power_spectral_density_matrix(
                    data, mask_noise.double())
            elif self.beamformer_type == "mpdr":
                # psd of observed signal
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power_speech = (data.real**2 +
                                data.imag**2) * mask_speech.double()
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speech = power_speech.mean(dim=-2)
                inverse_power = 1 / torch.clamp(power_speech, min=self.eps)
                # covariance of expanded observed speech
                psd_n = get_covariances(data,
                                        inverse_power,
                                        self.bdelay,
                                        self.btaps,
                                        get_vector=False)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_n,
                                             self.beamformer_type)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask)
                for mask in mask_speech
            ]
            if self.beamformer_type == "mvdr":
                # psd of noise
                if mask_noise is not None:
                    psd_n = get_power_spectral_density_matrix(data, mask_noise)
            elif self.beamformer_type == "mpdr":
                # psd of observed signal
                psd_n = FC.einsum("...ct,...et->...ce", [data, data.conj()])
            elif self.beamformer_type == "wpd":
                # Calculate power: (..., C, T)
                power = data.real**2 + data.imag**2
                power_speeches = [power * mask for mask in mask_speech]
                # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
                power_speeches = [ps.mean(dim=-2) for ps in power_speeches]
                inverse_powers = [
                    1 / torch.clamp(ps, min=self.eps) for ps in power_speeches
                ]
                # covariance of expanded observed speech
                psd_n = [
                    get_covariances(data,
                                    inv_ps,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_ps in inverse_powers
                ]
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            enhanced = []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    psd_noise = sum(psd_speeches)
                    if mask_noise is not None:
                        psd_noise = psd_noise + psd_n

                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_noise, self.beamformer_type)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(data, ilens, psd_speech, psd_n,
                                               self.beamformer_type)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(data, ilens, psd_speech,
                                               psd_n[i], self.beamformer_type)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
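The MPDR observation covariance above is a time-accumulated outer product over channels; a minimal sketch of that einsum (note FC.einsum takes the operand list as a single argument):

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

B, F_, C, T = 2, 9, 4, 60
data = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
# (B, F, C, T) x conj(B, F, C, T) -> (B, F, C, C)
psd_observed = FC.einsum("...ct,...et->...ce", [data, data.conj()])
assert psd_observed.size() == (B, F_, C, C)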
Example #10
    def forward(self,
                data: ComplexTensor,
                ilens: Optional[torch.LongTensor] = None,
                return_wpe: bool = True) -> Tuple[Optional[ComplexTensor],
                                                  torch.Tensor]:
        if ilens is None:
            ilens = torch.full((data.size(0),), data.size(2),
                               dtype=torch.long, device=data.device)

        r = -self.rcontext if self.rcontext != 0 else None
        enhanced = data[:, :, self.lcontext:r, :]

        if self.lcontext != 0 or self.rcontext != 0:
            assert all(ilens[0] == i for i in ilens)

            # Create context window (a.k.a Splicing)
            if self.model_type in ('blstm', 'lstm'):
                width = data.size(2) - self.lcontext - self.rcontext
                # data: (B, C, l + w + r, F) -> _y: (B, C, w * (1 + l + r), F)
                indices = [i + j for i in range(width)
                           for j in range(1 + self.lcontext + self.rcontext)]
                _y = data[:, :, indices]
                # _y -> data: (B, C, w, (1 + l + r) * F)
                data = _y.view(
                    data.size(0), data.size(1),
                    width, (1 + self.lcontext + self.rcontext) * data.size(3))
                ilens = torch.full((data.size(0),), width,
                                   dtype=torch.long, device=data.device)
                del _y

        for i in range(self.iterations):
            # Calculate power: (B, C, T, F)
            power = enhanced.real ** 2 + enhanced.imag ** 2
            if i == 0 and self.use_dnn:
                # mask: (B, C, T, F)
                mask = self.estimator(data, ilens)
                if mask.size(2) != power.size(2):
                    assert mask.size(2) == (power.size(2) + self.rcontext + self.lcontext)
                    r = -self.rcontext if self.rcontext != 0 else None
                    mask = mask[:, :, self.lcontext:r, :]

                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-2)[..., None]
                if self.out_type == 'mask':
                    power = power * mask
                else:
                    power = mask

                    if self.out_type == 'amplitude':
                        power = power ** 2
                    elif self.out_type == 'log_power':
                        power = power.exp()
                    elif self.out_type == 'power':
                        pass
                    else:
                        raise NotImplementedError(self.out_type)

            if not return_wpe:
                return None, power

            # power: (B, C, T, F) -> _power: (B, F, T)
            _power = power.mean(dim=1).transpose(-1, -2).contiguous()

            # data: (B, C, T, F) -> _data: (B, F, C, T)
            _data = data.permute(0, 3, 1, 2).contiguous()
            # _enhanced: (B, F, C, T)
            _enhanced_real = []
            _enhanced_imag = []
            for d, p, l in zip(_data, _power, ilens):
                # e: (F, C, T) -> (T, C, F)
                e = wpe_one_iteration(
                    d[..., :l], p[..., :l],
                    taps=self.taps, delay=self.delay,
                    inverse_power=self.inverse_power).transpose(0, 2)
                _enhanced_real.append(e.real)
                _enhanced_imag.append(e.imag)
            # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T)
            _enhanced_real = pad_sequence(_enhanced_real,
                                          batch_first=True).transpose(1, 3)
            _enhanced_imag = pad_sequence(_enhanced_imag,
                                          batch_first=True).transpose(1, 3)
            _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag)

            # enhanced: (B, F, C, T) -> (B, C, T, F)
            enhanced = _enhanced.permute(0, 2, 3, 1)

        # enhanced: (B, C, T, F), power: (B, C, T, F)
        return enhanced, power
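The splicing index construction can be sanity-checked on a toy sequence; each output frame i gathers the frames i .. i + lcontext + rcontext:

lcontext, rcontext, T = 2, 1, 7
width = T - lcontext - rcontext
indices = [i + j for i in range(width)
           for j in range(1 + lcontext + rcontext)]
assert len(indices) == width * (1 + lcontext + rcontext)
assert indices[:1 + lcontext + rcontext] == [0, 1, 2, 3]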
Example #11
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor,
               List[torch.Tensor]]:
        """DNN_WPE forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, C, F)
            ilens: (B,)
            masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
            power (List[torch.Tensor]): (B, F, T)
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = [data for i in range(self.nmask)]
        masks = None
        power = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = [enh.real**2 + enh.imag**2 for enh in enhanced]
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                masks, _ = self.mask_est(data, ilens)
                # floor masks to increase numerical stability
                if self.mask_flooring:
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Normalize along T
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = [p * masks[i] for i, p in enumerate(power)]

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = [
                wpe_one_iteration(
                    data.contiguous().double(),
                    p.double(),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                ) for p in power
            ]
            enhanced = [
                enh.to(dtype=data.dtype).masked_fill(
                    make_pad_mask(ilens, enh.real), 0) for enh in enhanced
            ]

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
        if masks is not None:
            masks = ([m.transpose(-1, -3) for m in masks]
                     if self.nmask > 1 else masks[0].transpose(-1, -3))
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power
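The keepdim normalization of the masks along time, standalone; after it, every (b, f, c) slice sums to one over T:

import torch

masks = [torch.rand(2, 9, 4, 60) for _ in range(2)]  # [(B, F, C, T)]
masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
assert torch.allclose(masks[0].sum(dim=-1), torch.ones(2, 9, 4))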
Example #12
    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (List[torch.Tensor]): (B, T, C, F)
        """
        def apply_beamforming(data,
                              ilens,
                              psd_n,
                              psd_speech,
                              psd_distortion=None):
            """Beamforming with the provided statistics.

            Args:
                data (ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
                psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C)
            Return:
                enhanced (ComplexTensor): (B, F, T)
                ws (ComplexTensor): (B, F, C) or (B, F, (btaps+1)*C)
            """
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                u = u.double()
            else:
                if self.beamformer_type.endswith("_souden"):
                    # (optional) Create onehot vector for fixed reference microphone
                    u = torch.zeros(*(data.size()[:-3] + (data.size(-2), )),
                                    device=data.device,
                                    dtype=torch.double)
                    u[..., self.ref_channel].fill_(1)
                else:
                    # for simplifying computation in RTF-based beamforming
                    u = self.ref_channel

            if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
                ws = get_mvdr_vector_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type in ("mpdr_souden", "mvdr_souden",
                                          "wmpdr_souden"):
                ws = get_mvdr_vector(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type == "wpd":
                ws = get_WPD_filter_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(ws, data.double(),
                                                 self.bdelay, self.btaps)
            elif self.beamformer_type == "wpd_souden":
                ws = get_WPD_filter_v2(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(ws, data.double(),
                                                 self.bdelay, self.btaps)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = data.double()

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if self.beamformer_type.startswith(
                    "wmpdr") or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real**2 + data_d.imag**2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = (power_input * mask_speech.double()).mean(dim=-2)
                else:
                    assert len(powers) == 1, len(powers)
                    powers = powers[0]
                inverse_power = 1 / torch.clamp(powers, min=self.eps)

            psd_speech = get_power_spectral_density_matrix(
                data_d, mask_speech.double())
            if mask_noise is not None and (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double())
            if self.beamformer_type == "mvdr":
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_noise,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "mvdr_souden":
                enhanced, ws = apply_beamforming(data, ilens, psd_noise,
                                                 psd_speech)
            elif self.beamformer_type == "mpdr":
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "mpdr_souden":
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                                 psd_speech)
            elif self.beamformer_type == "wmpdr":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :],
                     data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "wmpdr_souden":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :],
                     data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed,
                                                 psd_speech)
            elif self.beamformer_type == "wpd":
                psd_observed_bar = get_covariances(data_d,
                                                   inverse_power,
                                                   self.bdelay,
                                                   self.btaps,
                                                   get_vector=False)
                enhanced, ws = apply_beamforming(data,
                                                 ilens,
                                                 psd_observed_bar,
                                                 psd_speech,
                                                 psd_distortion=psd_noise)
            elif self.beamformer_type == "wpd_souden":
                psd_observed_bar = get_covariances(data_d,
                                                   inverse_power,
                                                   self.bdelay,
                                                   self.btaps,
                                                   get_vector=False)
                enhanced, ws = apply_beamforming(data, ilens, psd_observed_bar,
                                                 psd_speech)
            else:
                raise ValueError("Not supporting beamformer_type={}".format(
                    self.beamformer_type))

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            if self.beamformer_type.startswith(
                    "wmpdr") or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real**2 + data_d.imag**2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = [(power_input * m.double()).mean(dim=-2)
                              for m in mask_speech]
                else:
                    assert len(powers) == self.num_spk, len(powers)
                inverse_power = [
                    1 / torch.clamp(p, min=self.eps) for p in powers
                ]

            psd_speeches = [
                get_power_spectral_density_matrix(data_d, mask.double())
                for mask in mask_speech
            ]
            if mask_noise is not None and (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double())
            if self.beamformer_type in ("mpdr", "mpdr_souden"):
                psd_observed = FC.einsum("...ct,...et->...ce",
                                         [data_d, data_d.conj()])
            elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
                psd_observed = [
                    FC.einsum(
                        "...ct,...et->...ce",
                        [data_d * inv_p[..., None, :],
                         data_d.conj()],
                    ) for inv_p in inverse_power
                ]
            elif self.beamformer_type in ("wpd", "wpd_souden"):
                psd_observed_bar = [
                    get_covariances(data_d,
                                    inv_p,
                                    self.bdelay,
                                    self.btaps,
                                    get_vector=False)
                    for inv_p in inverse_power
                ]

            enhanced, ws = [], []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                if (self.beamformer_type == "mvdr_souden"
                        or not self.beamformer_type.endswith("_souden")):
                    psd_noise_i = (psd_noise + sum(psd_speeches) if mask_noise
                                   is not None else sum(psd_speeches))
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type == "mvdr":
                    enh, w = apply_beamforming(data,
                                               ilens,
                                               psd_noise_i,
                                               psd_speech,
                                               psd_distortion=psd_noise_i)
                elif self.beamformer_type == "mvdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_noise_i,
                                               psd_speech)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed,
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "mpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed,
                                               psd_speech)
                elif self.beamformer_type == "wmpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wmpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed[i],
                                               psd_speech)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed_bar[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wpd_souden":
                    enh, w = apply_beamforming(data, ilens,
                                               psd_observed_bar[i], psd_speech)
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(
                            self.beamformer_type))
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
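The wMPDR covariance differs from the plain MPDR one only by per-frame weights; a sketch using the same einsum convention as above (sizes arbitrary):

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

B, F_, C, T = 2, 9, 4, 60
data = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
power = torch.rand(B, F_, T).clamp(min=1e-8)  # (B, F, T) speech power
inverse_power = 1 / power
# Weight each frame before the outer product: (B, F, C, C)
psd_weighted = FC.einsum("...ct,...et->...ce",
                         [data * inverse_power[..., None, :], data.conj()])
assert psd_weighted.size() == (B, F_, C, C)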
Example #13
    def forward(self,
                xs: ComplexTensor,
                ts: ComplexTensor,
                ilens: torch.LongTensor,
                loss_types: Union[str, Sequence[str]] = 'power_mse',
                ref_channel: int = 0) -> Dict[str, torch.Tensor]:
        # xs: (B, C, T, F), ts: (B, C, T, F)
        if isinstance(loss_types, str):
            loss_types = [loss_types]

        # ys: (B, C, T, F), power: (B, C, T, F)
        for loss_type in loss_types:
            if 'dnnwpe' in loss_type:
                return_wpe = True
                break
        else:
            return_wpe = False
        ys, power = self.model(xs, ilens, return_wpe=return_wpe)

        r = -self.rcontext if self.rcontext != 0 else None
        xs = xs[:, :, self.lcontext:r, :]

        if ys is not None:
            assert xs.shape == ys.shape, (xs.shape, ys.shape)
        assert xs.shape == power.shape, (xs.shape, power.shape)

        uts = None
        uys = None
        upower = None
        ys_time = None
        ts_time = None
        xs_time = None

        loss_dict = OrderedDict()
        for loss_type in loss_types:
            if loss_type == 'dnnwpe_power_mse':
                if uys is None:
                    uys = FC.cat(unpad(ys, ilens, length_dim=2), dim=1)
                if uts is None:
                    uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)

                _ys = uys.real**2 + uys.imag**2
                _ts = uts.real**2 + uts.imag**2

                _ys = _ys.log()
                _ts = _ts.log()
                loss = mse_loss(_ys, _ts)

            elif loss_type == 'dnnwpe_mse':
                if uys is None:
                    uys = FC.cat(unpad(ys, ilens, length_dim=2), dim=1)
                if uts is None:
                    uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)

                _ys = torch.cat([uys.real, uys.imag], dim=-1)
                _ts = torch.cat([uts.real, uts.imag], dim=-1)
                loss = mse_loss(_ys, _ts)

            elif loss_type == 'power_mse':
                if upower is None:
                    upower = torch.cat(unpad(power, ilens, length_dim=2),
                                       dim=1)
                if uts is None:
                    uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)

                _ts = uts.real**2 + uts.imag**2

                loss = mse_loss(upower, _ts)

            # For evaluation only; not differentiable
            elif loss_type == 'dnnwpe_stoi':

                # Use only the first channel to speed up the calculation
                if ys_time is None:
                    # _ys: List[torch.Tensor]: B x [C, T, F]
                    _ys = unpad(ys, ilens, length_dim=2)
                    # ys_time: List[np.ndarray]: B x [T]
                    ys_time = [
                        self.stft_func.istft(_y[0].cpu().numpy().T)
                        for _y in _ys
                    ]
                if ts_time is None:
                    # _ts: List[torch.Tensor]: B x [C, T, F]
                    _ts = unpad(ts, ilens, length_dim=2)
                    # ts_time: List[np.ndarray]: B x [T]
                    ts_time = [
                        self.stft_func.istft(_t[0].cpu().numpy().T)
                        for _t in _ts
                    ]

                _losses = []

                for _y, _t in zip(ys_time, ts_time):
                    # Single channel only
                    _losses.append(stoi(_t, _y, self.stft_func.fs))
                loss = torch.tensor(numpy.mean(_losses))

            # For evaluation only; not differentiable
            elif loss_type == 'dnnwpe_pesq':

                # Use only the first channel to speed up the calculation
                if ys_time is None:
                    # _ys: List[torch.Tensor]: B x [C, T, F]
                    _ys = unpad(ys, ilens, length_dim=2)
                    # ys_time: List[np.ndarray]: B x [T]
                    ys_time = [
                        self.stft_func.istft(_y[0].cpu().numpy().T)
                        for _y in _ys
                    ]
                if ts_time is None:
                    # _ts: List[torch.Tensor]: B x [C, T, F]
                    _ts = unpad(ts, ilens, length_dim=2)
                    # ts_time: List[np.ndarray]: B x [T]
                    ts_time = [
                        self.stft_func.istft(_t[0].cpu().numpy().T)
                        for _t in _ts
                    ]

                _fns = []
                # PESQ via subprocess can be parallelized with threading
                e = ThreadPoolExecutor(self.pesq_nworker)
                for _y, _t in zip(ys_time, ts_time):
                    _y *= numpy.iinfo(numpy.int16).max - 1
                    _y = _y.astype(numpy.int16)

                    _t *= numpy.iinfo(numpy.int16).max - 1
                    _t = _t.astype(numpy.int16)
                    fn = e.submit(calc_pesq, _t, _y, self.stft_func.fs)
                    _fns.append(fn)

                _losses = []
                for fn in _fns:
                    v = fn.result()
                    _losses.append(v)

                loss = torch.tensor(numpy.mean(_losses))

            # For evaluation only; not differentiable
            elif loss_type == 'unprocessed_pesq':
                # Use only the first channel to speed up the calculation
                if xs_time is None:
                    # _xs: List[torch.Tensor]: B x [C, T, F]
                    _xs = unpad(xs, ilens, length_dim=2)
                    # xs_time: List[np.ndarray]: B x [T]
                    xs_time = [
                        self.stft_func.istft(_x[0].cpu().numpy().T)
                        for _x in _xs
                    ]
                if ts_time is None:
                    # _ts: List[torch.Tensor]: B x [C, T, F]
                    _ts = unpad(ts, ilens, length_dim=2)
                    # ts_time: List[np.ndarray]: B x [T]
                    ts_time = [
                        self.stft_func.istft(_t[0].cpu().numpy().T)
                        for _t in _ts
                    ]

                _fns = []

                # PESQ via subprocess can be parallelized with threading
                e = ThreadPoolExecutor(self.pesq_nworker)
                for _x, _t in zip(xs_time, ts_time):
                    _x = _x * (numpy.iinfo(numpy.int16).max - 1)
                    _x = _x.astype(numpy.int16)

                    _t = _t * (numpy.iinfo(numpy.int16).max - 1)
                    _t = _t.astype(numpy.int16)
                    fn = e.submit(calc_pesq, _t, _x, self.stft_func.fs)
                    _fns.append(fn)

                _losses = []
                for fn in _fns:
                    v = fn.result()
                    _losses.append(v)

                loss = torch.tensor(numpy.mean(_losses))

            elif loss_type == 'wpe_pesq':
                with torch.no_grad():
                    # (B, C, T, F) -> (B, F, C, T)
                    _xs = xs.permute(0, 3, 1, 2).contiguous()
                    # _ys: (B, F, C, T)
                    _ys = wpe(_xs, 5, 3, 3)[:, :, ref_channel]
                    _ys = unpad(_ys, ilens, length_dim=2)
                    ys_time = [
                        self.stft_func.istft(_y.cpu().numpy()) for _y in _ys
                    ]

                if ts_time is None:
                    # _ts: List[torch.Tensor]: B x [C, T, F]
                    _ts = unpad(ts, ilens, length_dim=2)
                    # ts_time: List[np.ndarray]: B x [T]
                    ts_time = [
                        self.stft_func.istft(_t[0].cpu().numpy().T)
                        for _t in _ts
                    ]

                _fns = []

                # PESQ via subprocess can be parallelized with threading
                e = ThreadPoolExecutor(self.pesq_nworker)
                for _y, _t in zip(ys_time, ts_time):
                    _y *= numpy.iinfo(numpy.int16).max - 1
                    _y = _y.astype(numpy.int16)

                    _t *= numpy.iinfo(numpy.int16).max - 1
                    _t = _t.astype(numpy.int16)
                    fn = e.submit(calc_pesq, _t, _y, self.stft_func.fs)
                    _fns.append(fn)

                _losses = []
                for fn in _fns:
                    v = fn.result()
                    _losses.append(v)

                loss = torch.tensor(numpy.mean(_losses))

            elif loss_type == 'wpe_mse':
                # Note: no parameters are updated by this loss
                # 96328786.2853478
                if uts is None:
                    uts = FC.cat(unpad(ts, ilens, length_dim=2), dim=1)
                with torch.no_grad():
                    # (B, C, T, F) -> (B, F, C, T)
                    _xs = xs.permute(0, 3, 1, 2)
                    # _ys: (B, F, C, T) -> (B, T, F)
                    _ys = wpe(_xs, 5, 3, 3)[:, :, ref_channel].transpose(1, 2)
                    _uys = FC.cat(unpad(_ys, ilens, length_dim=1), dim=0)
                    _ys = _uys.real**2 + _uys.imag**2
                    _ts = uts.real**2 + uts.imag**2
                    _ts = _ts[ref_channel]

                    loss = mse_loss(_ys, _ts)

            else:
                raise NotImplementedError(f'loss_type={loss_type}')

            # Don't return scalar
            loss_dict[loss_type] = loss[None]

        return loss_dict
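The thread-pool fan-out used by the PESQ branches, in skeleton form; score() below is a hypothetical stand-in for the external calc_pesq helper:

import numpy
from concurrent.futures import ThreadPoolExecutor

def score(ref, deg, fs):
    # hypothetical stand-in for calc_pesq (which shells out to a PESQ binary)
    return float(numpy.mean(numpy.abs(ref.astype(numpy.float64) - deg)))

pairs = [(numpy.random.randn(8000), numpy.random.randn(8000))
         for _ in range(4)]
with ThreadPoolExecutor(max_workers=2) as e:
    futures = [e.submit(score, t, y, 16000) for y, t in pairs]
    losses = [f.result() for f in futures]
mean_loss = numpy.mean(losses)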