Beispiel #1
0
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Run iterative WPE dereverberation, optionally DNN-mask supported.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            enhanced: (B, T, C, F), cast back to the dtype of ``data``
                (the WPE step itself is computed in double precision)
            ilens: (B,), returned unchanged
            mask: (B, T, C, F) mask estimated in the first iteration, or
                None when ``self.use_dnn_mask`` is false or
                ``self.iterations`` is 0
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T); the unpacking expects exactly one mask
                (mask, ), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T so each (B, F, C) slice sums to one
                    mask = mask / mask.sum(dim=-1, keepdim=True)
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = wpe_one_iteration(
                data.contiguous().double(),
                power.double(),
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )
            # Return to the input precision and zero out padded frames
            enhanced = enhanced.type(data.dtype)
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            # (B, F, C, T) -> (B, T, C, F)
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
Beispiel #2
0
    def forward(self,
                data: ComplexTensor, ilens: torch.LongTensor=None,
                return_wpe: bool=True) -> Tuple[Optional[ComplexTensor],
                                                torch.Tensor]:
        """Estimate the power spectrum with a DNN and optionally apply WPE.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Frequency bins

        Args:
            data: (B, C, T, F) observed signal including ``self.lcontext``
                leading and ``self.rcontext`` trailing context frames.
            ilens: (B,) lengths; defaults to ``data.size(2)`` for every item.
            return_wpe: If False, return ``(None, power)`` right after the
                first power estimate, without running WPE.
        Returns:
            enhanced: (B, C, T', F) dereverberated signal with the context
                frames removed (time axis padded to the longest length),
                None when ``return_wpe`` is False.
            power: (B, C, T', F) power estimate used by WPE (None if
                ``self.iterations`` is 0).
        Raises:
            ValueError: context frames are used but ``ilens`` differ.
        """
        if ilens is None:
            ilens = torch.full((data.size(0),), data.size(2),
                               dtype=torch.long, device=data.device)

        r = -self.rcontext if self.rcontext != 0 else None
        # Observed signal without the context frames. WPE and the power
        # estimate work on this; the DNN may additionally see the context.
        observed = data[:, :, self.lcontext:r, :]
        enhanced = observed
        # Lengths of ``observed`` (only differ from ``ilens`` with context)
        wpe_ilens = ilens
        power = None

        if self.lcontext != 0 or self.rcontext != 0:
            if not all(ilens[0] == i for i in ilens):
                raise ValueError(
                    'All ilens must be equal when context frames are used')
            wpe_ilens = torch.full((data.size(0),), observed.size(2),
                                   dtype=torch.long, device=data.device)

            # Create context window (a.k.a Splicing)
            if self.model_type in ('blstm', 'lstm'):
                width = data.size(2) - self.lcontext - self.rcontext
                context = 1 + self.lcontext + self.rcontext
                # data: (B, C, l + w + r, F)
                indices = [i + j for i in range(width)
                           for j in range(context)]
                # _y: (B, C, w * (1 + l + r), F)
                _y = data[:, :, indices]
                # data: (B, C, w, (1 + l + r) * F)
                data = _y.view(
                    data.size(0), data.size(1),
                    width, context * data.size(3))
                ilens = torch.full((data.size(0),), width,
                                   dtype=torch.long, device=data.device)
                del _y

        for i in range(self.iterations):
            # Calculate power: (B, C, T, F)
            power = enhanced.real ** 2 + enhanced.imag ** 2
            if i == 0 and self.use_dnn:
                # mask: (B, C, T, F); the DNN sees (possibly spliced) ``data``
                mask = self.estimator(data, ilens)
                if mask.size(2) != power.size(2):
                    # The estimator kept the context frames: drop them
                    assert mask.size(2) == (power.size(2) + self.rcontext + self.lcontext)
                    mask = mask[:, :, self.lcontext:r, :]

                if self.normalization:
                    # Normalize along T (dim=-2); keepdim gives (B, C, 1, F)
                    # so the division broadcasts over the time axis.
                    mask = mask / mask.sum(dim=-2, keepdim=True)
                if self.out_type == 'mask':
                    power = power * mask
                else:
                    power = mask

                    if self.out_type == 'amplitude':
                        power = power ** 2
                    elif self.out_type == 'log_power':
                        power = power.exp()
                    elif self.out_type == 'power':
                        pass
                    else:
                        raise NotImplementedError(self.out_type)

            if not return_wpe:
                return None, power

            # power: (B, C, T, F) -> _power: (B, F, T)
            _power = power.mean(dim=1).transpose(-1, -2).contiguous()

            # observed: (B, C, T, F) -> _data: (B, F, C, T)
            # (the unspliced, context-free signal, so F and T match _power)
            _data = observed.permute(0, 3, 1, 2).contiguous()
            # _enhanced: (B, F, C, T)
            _enhanced_real = []
            _enhanced_imag = []
            for d, p, length in zip(_data, _power, wpe_ilens):
                # e: (F, C, T) -> (T, C, F)
                e = wpe_one_iteration(
                    d[..., :length], p[..., :length],
                    taps=self.taps, delay=self.delay,
                    inverse_power=self.inverse_power).transpose(0, 2)
                _enhanced_real.append(e.real)
                _enhanced_imag.append(e.imag)
            # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T)
            _enhanced_real = pad_sequence(_enhanced_real,
                                          batch_first=True).transpose(1, 3)
            _enhanced_imag = pad_sequence(_enhanced_imag,
                                          batch_first=True).transpose(1, 3)
            _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag)

            # enhanced: (B, F, C, T) -> (B, C, T, F)
            enhanced = _enhanced.permute(0, 2, 3, 1)

        # enhanced: (B, C, T, F), power: (B, C, T, F)
        return enhanced, power
Beispiel #3
0
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Apply DNN-supported WPE, producing one estimate per mask.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F) observed signal
            ilens: (B,) valid lengths
        Returns:
            enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F);
                a single tensor when ``self.nmask == 1``, otherwise a list
            ilens: (B,), passed through unchanged
            masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F), or
                None when no mask was estimated
            power (List[torch.Tensor]): (B, F, T), one entry per estimate;
                None when ``self.iterations`` is 0
        """
        # Reorder to the WPE layout: (B, T, C, F) -> (B, F, C, T)
        observed = data.permute(0, 3, 2, 1)
        estimates = [observed] * self.nmask
        masks = None
        power = None

        for step in range(self.iterations):
            # Power spectrum of every current estimate: (..., C, T)
            power = [est.real**2 + est.imag**2 for est in estimates]

            if step == 0 and self.use_dnn_mask:
                # Masks in (B, F, C, T); estimated only on the first pass
                masks, _ = self.mask_est(observed, ilens)
                if self.mask_flooring:
                    # Lower-bound the masks for numerical stability
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Make every mask sum to one along T
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # Weight each power spectrum by the mask with the same index
                power = [pw * masks[k] for k, pw in enumerate(power)]

            # Channel average, floored at eps: (..., C, T) -> (..., T)
            power = [pw.mean(dim=-2).clamp(min=self.eps) for pw in power]

            # NOTE(kamo): Calculate in double precision
            refined = []
            for pw in power:
                refined.append(
                    wpe_one_iteration(
                        observed.contiguous().double(),
                        pw.double(),
                        taps=self.taps,
                        delay=self.delay,
                        inverse_power=self.inverse_power,
                    )
                )

            # Back to the input precision, with padded frames zeroed
            estimates = []
            for raw in refined:
                converted = raw.to(dtype=observed.dtype)
                estimates.append(
                    converted.masked_fill(make_pad_mask(ilens, raw.real), 0)
                )

        # Back to the caller's layout: (B, F, C, T) -> (B, T, C, F)
        estimates = [est.permute(0, 3, 2, 1) for est in estimates]
        if masks is not None:
            if self.nmask > 1:
                masks = [m.transpose(-1, -3) for m in masks]
            else:
                masks = masks[0].transpose(-1, -3)
        if self.nmask == 1:
            estimates = estimates[0]

        return estimates, ilens, masks, power