Beispiel #1
0
    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """Estimate one or more masks from a multichannel complex spectrum.

        The magnitude of every channel is run through ``self.brnn`` as an
        independent sequence, then each layer in ``self.linears`` produces
        one mask, followed by the nonlinearity selected by ``self.nonlinear``.

        Args:
            xs: Complex input of shape (B, F, C, T).
            ilens: Number of valid frames per utterance, shape (B,).
        Returns:
            masks: Tuple with one mask per entry of ``self.linears``, each of
                shape (B, F, C, T). Positions flagged by ``make_pad_mask``
                (frames beyond ``ilens``) are set to 0, and the time axis is
                right-padded with zeros up to the input length if needed.
            ilens: The input ``ilens``, returned unchanged.
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Magnitude spectrum: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2) ** 0.5
        # Fold channels into the batch so the RNN treats every channel as its
        # own sequence: (B, C, T, F) -> (B * C, T, F)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # Repeat each utterance length once per channel: (B,) -> (B * C,)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # (B * C, T, F) -> (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # Unfold channels again: (B * C, T, D) -> (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # (B, C, T, D) -> (B, C, T, F)
            mask = linear(xs)

            if self.nonlinear == "sigmoid":
                mask = torch.sigmoid(mask)
            elif self.nonlinear == "relu":
                mask = torch.relu(mask)
            elif self.nonlinear == "tanh":
                mask = torch.tanh(mask)
            elif self.nonlinear == "crelu":
                # Clipped ReLU: restrict values to [0, 1]
                mask = torch.clamp(mask, min=0, max=1)
            # Any other value of self.nonlinear leaves the linear output as is.

            # Zero the padded frames (time is dim 2 here). masked_fill returns
            # a new tensor, so the result must be reassigned to take effect.
            mask = mask.masked_fill(
                make_pad_mask(ilens, mask, length_dim=2), 0
            )

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Multi-GPU case: the RNN output may only span max(ilens) frames
            # of this shard, shorter than the input; right-pad time with 0.
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
Beispiel #2
0
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Run WPE dereverberation, optionally weighted by a single DNN mask.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F) complex spectrum. Each WPE iteration is
                computed in double precision and the result is cast back to
                ``data``'s dtype.
            ilens: (B,) number of valid frames per utterance.
        Returns:
            enhanced: (B, T, C, F), same dtype as ``data``. After every WPE
                iteration, positions flagged by ``make_pad_mask`` are zeroed.
            ilens: (B,), returned unchanged.
            mask: (B, T, C, F) mask from ``self.mask_est``, or None when the
                DNN mask is not used (``use_dnn_mask`` is False or
                ``iterations`` is 0).
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T). mask_est must return exactly one mask.
                # At i == 0, ``enhanced`` is still ``data``.
                (mask,), _ = self.mask_est(data, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1, keepdim=True)
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = wpe_one_iteration(
                data.contiguous().double(),
                power.double(),
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )
            # Back to the input dtype, then zero the padded frames in place.
            enhanced = enhanced.type(data.dtype)
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            # (B, F, C, T) -> (B, T, C, F)
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
Beispiel #3
0
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Dereverberate with WPE, optionally weighting power by DNN masks.

        Notation:
            B: batch, C: channel, T: frames,
            F: frequency bins (or feature dimension).

        Args:
            data: (B, T, C, F) complex spectrum.
            ilens: (B,) number of valid frames per utterance.
        Returns:
            enhanced: One (B, T, C, F) result per mask branch, returned as a
                bare tensor instead of a list when ``self.nmask == 1``.
            ilens: (B,), passed through unchanged.
            masks: None unless the DNN masks were estimated (``use_dnn_mask``
                and at least one iteration). Otherwise there is one
                (B, T, C, F) mask per branch, returned as a bare tensor when
                ``self.nmask == 1``.
            power: List of (B, F, T) channel-averaged power from the last
                iteration, floored at ``self.eps``. None when
                ``self.iterations`` is 0.
        """
        # Put time last: (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = [data] * self.nmask
        masks = None
        power = None

        for iteration in range(self.iterations):
            # Per-branch power spectrum: (..., C, T)
            power = []
            for branch in enhanced:
                power.append(branch.real**2 + branch.imag**2)

            if iteration == 0 and self.use_dnn_mask:
                # Masks come out as (B, F, C, T)
                masks, _ = self.mask_est(data, ilens)
                if self.mask_flooring:
                    # A lower bound on the masks keeps later divisions stable
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Make every mask sum to one along the time axis
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                power = [p * masks[k] for k, p in enumerate(power)]

            # Average over channels, then floor: (..., C, T) -> (..., T)
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            # One WPE step per branch, evaluated in float64 (NOTE(kamo)), then
            # cast back to the input dtype with the padded frames zeroed.
            refined = []
            for p in power:
                step = wpe_one_iteration(
                    data.contiguous().double(),
                    p.double(),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                )
                step = step.to(dtype=data.dtype)
                refined.append(
                    step.masked_fill(make_pad_mask(ilens, step.real), 0)
                )
            enhanced = refined

        # Restore the caller's layout: (B, F, C, T) -> (B, T, C, F)
        enhanced = [branch.permute(0, 3, 2, 1) for branch in enhanced]
        if masks is not None:
            if self.nmask > 1:
                masks = [m.transpose(-1, -3) for m in masks]
            else:
                masks = masks[0].transpose(-1, -3)
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power