Example #1
    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]
            ilens (torch.Tensor): input lengths [Batch]
        """
        if not isinstance(input, ComplexTensor) and (
                is_torch_1_9_plus and not torch.is_complex(input)):
            raise TypeError("Only complex tensors are supported for the STFT decoder")

        bs = input.size(0)
        if input.dim() == 4:
            multi_channel = True
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            input = input.transpose(1, 2).reshape(-1, input.size(1),
                                                  input.size(3))
        else:
            multi_channel = False

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens
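To make the reshape bookkeeping concrete, here is a minimal PyTorch-only sketch of the multi-channel round trip above; the sizes are hypothetical and the iSTFT itself is replaced by a random waveform:

import torch

B, T, C, F = 2, 50, 4, 129
x = torch.randn(B, T, C, F)
# (B, T, C, F) -> (B, C, T, F) -> (B * C, T, F)
flat = x.transpose(1, 2).reshape(-1, x.size(1), x.size(3))
assert flat.shape == (B * C, T, F)
# Suppose each (T, F) spectrum was inverted to a waveform of N samples:
N = 1000
wav = torch.randn(B * C, N)
# (B * C, N) -> (B, C, N) -> (B, N, C), matching the comment in forward()
wav = wav.reshape(B, -1, wav.size(1)).transpose(1, 2)
assert wav.shape == (B, N, C)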
Example #2
    def forward(self,
                psd_in: ComplexTensor,
                ilens: torch.LongTensor,
                scaling: float = 2.0) -> Tuple[torch.Tensor, torch.LongTensor]:
        """The forward function

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float):
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C)
        datatype = torch.bool if is_torch_1_2_plus else torch.uint8
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=datatype, device=psd_in.device), 0)
        # psd: (B, F, C, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real**2 + psd.imag**2)**0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens
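The diagonal-zeroing step above relies on broadcasting a (C, C) boolean eye over the batch and frequency axes. A minimal sketch, assuming a recent PyTorch with native complex tensors in place of ComplexTensor:

import torch

B, F_, C = 2, 5, 4
psd_in = torch.randn(B, F_, C, C, dtype=torch.complex64)
diag = torch.eye(C, dtype=torch.bool)       # broadcasts over (B, F_)
psd = psd_in.masked_fill(diag, 0)           # zero every diagonal entry
# Average the off-diagonal entries and swap C and F:
psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
assert psd.shape == (B, C, F_)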
Example #3
def trace(a: ComplexTensor) -> ComplexTensor:
    if LooseVersion(torch.__version__) >= LooseVersion('1.3'):
        datatype = torch.bool
    else:
        datatype = torch.uint8
    E = torch.eye(a.real.size(-1), dtype=datatype).expand(*a.size())
    if LooseVersion(torch.__version__) >= LooseVersion('1.1'):
        E = E.type(torch.bool)
    return a[E].view(*a.size()[:-1]).sum(-1)
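A quick consistency check for trace(), assuming torch_complex is installed: the result should equal the diagonal sums of the real and imaginary parts taken separately.

import torch
from torch_complex.tensor import ComplexTensor

a = ComplexTensor(torch.rand(2, 3, 4, 4), torch.rand(2, 3, 4, 4))
t = trace(a)  # (2, 3)
assert torch.allclose(t.real, a.real.diagonal(dim1=-2, dim2=-1).sum(-1))
assert torch.allclose(t.imag, a.imag.diagonal(dim1=-2, dim2=-1).sum(-1))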
Example #4
def get_covariances(
    Y: ComplexTensor,
    inverse_power: torch.Tensor,
    bdelay: int,
    btaps: int,
    get_vector: bool = False,
) -> ComplexTensor:
    """Calculates the power normalized spatio-temporal

     covariance matrix of the framed signal.

    Args:
        Y : Complext STFT signal with shape (B, F, C, T)
        inverse_power : Weighting factor with shape (B, F, T)

    Returns:
        Correlation matrix of shape (B, F, (btaps+1) * C, (btaps+1) * C)
        Correlation vector of shape (B, F, btaps + 1, C, C)
    """
    assert inverse_power.dim() == 3, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))

    Bs, Fdim, C, T = Y.shape

    # (B, F, C, T - bdelay - btaps + 1, btaps + 1)
    Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
        ..., : T - bdelay - btaps + 1, :
    ]
    # Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1]
    Psi = FC.reverse(Psi, dim=-1)
    Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None]

    # let T' = T - bdelay - btaps + 1
    # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
    #  -> (B, F, btaps + 1, C, btaps + 1, C)
    covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj()))

    # (B, F, btaps + 1, C, btaps + 1, C)
    #   -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
    covariance_matrix = covariance_matrix.view(
        Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
    )

    if get_vector:
        # (B, F, C, T', btaps + 1) x (B, F, C, T')
        #    --> (B, F, btaps + 1, C, C)
        covariance_vector = FC.einsum(
            "bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj())
        )
        return covariance_matrix, covariance_vector
    else:
        return covariance_matrix
Example #5
def get_mvdr_vector(psd_s: ComplexTensor,
                    psd_n: ComplexTensor,
                    reference_vector: torch.Tensor,
                    eps: float = 1e-15) -> ComplexTensor:
    """Return the MVDR(Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n += eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum('...ec,...cd->...ed', [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum('...fec,...c->...fe', [ws, reference_vector])
    return beamform_vector
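A hedged usage sketch (sizes and the Hermitian construction are illustrative; torch_complex assumed):

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

def random_psd(B, F_, C):
    v = ComplexTensor(torch.randn(B, F_, C, C), torch.randn(B, F_, C, C))
    # v @ v^H is Hermitian positive semi-definite
    return FC.matmul(v, v.conj().transpose(-1, -2))

B, F_, C = 2, 65, 4
psd_s, psd_n = random_psd(B, F_, C), random_psd(B, F_, C)
u = torch.zeros(B, C)
u[:, 0] = 1.0  # one-hot reference channel
h = get_mvdr_vector(psd_s, psd_n, u)
assert h.shape == (B, F_, C)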
Example #6
    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real ** 2 + h.imag ** 2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
Example #7
def test_complex_norm(dim):
    mat = ComplexTensor(torch.rand(2, 3, 4), torch.rand(2, 3, 4))
    mat_th = torch.complex(mat.real, mat.imag)
    norm = complex_norm(mat, dim=dim, keepdim=True)
    norm_th = complex_norm(mat_th, dim=dim, keepdim=True)
    assert (torch.allclose(norm, norm_th) and norm.ndim == mat.ndim
            and mat.numel() == norm.numel() * mat.size(dim))
Example #8
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
Example #9
    def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
            -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)

        """
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        (mask_speech, mask_noise), _ = self.mask(data, ilens)

        psd_speech = get_power_spectral_density_matrix(data, mask_speech)
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)

        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech, ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)

        ws = get_mvdr_vector(psd_speech, psd_noise, u)
        enhanced = apply_beamforming_vector(ws, data)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
        mask_speech = mask_speech.transpose(-1, -3)

        return enhanced, ilens, mask_speech
Example #10
def complex_matrix2real_matrix(c: ComplexTensor) -> torch.Tensor:
    # NOTE(kamo):
    # Complex value can be expressed as follows
#   a + bi => a * x + b * y
    # where
    #   x = |1 0|  y = |0 -1|
    #       |0 1|,     |1  0|
    # A complex matrix can be also expressed as
    #   |A -B|
    #   |B  A|
    # and complex vector can be expressed as
    #   |A|
    #   |B|
    assert c.size(-2) == c.size(-1), c.size()
    # (*, m, m) -> (*, 2m, 2m)
    return torch.cat(
        [torch.cat([c.real, -c.imag], dim=-1), torch.cat([c.imag, c.real], dim=-1)],
        dim=-2,
    )
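A numerical sanity check (plain PyTorch): complex matrix multiplication must agree with real multiplication of the embedded block matrices.

import torch

def embed(re, im):
    return torch.cat([torch.cat([re, -im], dim=-1),
                      torch.cat([im, re], dim=-1)], dim=-2)

ar, ai = torch.randn(3, 3), torch.randn(3, 3)
br, bi = torch.randn(3, 3), torch.randn(3, 3)
prod = torch.complex(ar, ai) @ torch.complex(br, bi)
prod2 = embed(ar, ai) @ embed(br, bi)
assert torch.allclose(prod2[:3, :3], prod.real, atol=1e-5)
assert torch.allclose(prod2[3:, :3], prod.imag, atol=1e-5)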
Example #11
    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            masks: A tuple of the estimated masks, each (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2)**0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            if self.nonlinear == "sigmoid":
                mask = torch.sigmoid(mask)
            elif self.nonlinear == "relu":
                mask = torch.relu(mask)
            elif self.nonlinear == "tanh":
                mask = torch.tanh(mask)
            elif self.nonlinear == "crelu":
                mask = torch.clamp(mask, min=0, max=1)
            # Zero padding
            mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Takes care of multi-GPU cases where input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
Example #12
def get_correlations(Y: ComplexTensor, inverse_power: torch.Tensor, taps,
                     delay) -> Tuple[ComplexTensor, ComplexTensor]:
    """Calculates weighted correlations of a window of length taps

    Args:
        Y : Complex-valued STFT signal with shape (F, C, T)
        inverse_power : Weighting factor with shape (F, T)
        taps (int): Length of the correlation window
        delay (int): Delay for the weighting factor

    Returns:
        Correlation matrix of shape (F, taps*C, taps*C)
        Correlation vector of shape (F, taps, C, C)
    """
    assert inverse_power.dim() == 2, inverse_power.dim()
    assert inverse_power.size(0) == Y.size(0), \
        (inverse_power.size(0), Y.size(0))

    F, C, T = Y.size()

    # Y: (F, C, T) -> Psi: (F, C, T - delay - taps + 1, taps)
    Psi = signal_framing(Y, frame_length=taps,
                         frame_step=1)[..., :T - delay - taps + 1, :]
    # Reverse along taps-axis
    Psi = FC.reverse(Psi, dim=-1)
    Psi_conj_norm = \
        Psi.conj() * inverse_power[..., None, delay + taps - 1:, None]

    # (F, C, T, taps) x (F, C, T, taps) -> (F, taps, C, taps, C)
    correlation_matrix = FC.einsum('fdtk,fetl->fkdle', (Psi_conj_norm, Psi))
    # (F, taps, C, taps, C) -> (F, taps * C, taps * C)
    correlation_matrix = correlation_matrix.view(F, taps * C, taps * C)

    # (F, C, T, taps) x (F, C, T) -> (F, taps, C, C)
    correlation_vector = FC.einsum('fdtk,fet->fked',
                                   (Psi_conj_norm, Y[..., delay + taps - 1:]))

    return correlation_matrix, correlation_vector
Example #13
def get_filter_matrix_conj(correlation_matrix: ComplexTensor,
                           correlation_vector: ComplexTensor,
                           eps: float = 1e-10) -> ComplexTensor:
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        correlation_matrix : Correlation matrix (F, taps * C, taps * C)
        correlation_vector : Correlation vector (F, taps, C, C)
        eps:

    Returns:
        filter_matrix_conj (ComplexTensor): (F, taps, C, C)
    """
    F, taps, C, _ = correlation_vector.size()

    # (F, taps, C1, C2) -> (F, C1, taps, C2) -> (F, C1, taps * C2)
    correlation_vector = \
        correlation_vector.permute(0, 2, 1, 3)\
        .contiguous().view(F, C, taps * C)

    eye = torch.eye(correlation_matrix.size(-1),
                    dtype=correlation_matrix.dtype,
                    device=correlation_matrix.device)
    shape = tuple(1 for _ in range(correlation_matrix.dim() - 2)) + \
        correlation_matrix.shape[-2:]
    eye = eye.view(*shape)
    correlation_matrix += eps * eye

    inv_correlation_matrix = correlation_matrix.inverse()
    # (F, C, taps * C) x (F, taps * C, taps * C) -> (F, C, taps * C)
    stacked_filter_conj = FC.matmul(correlation_vector,
                                    inv_correlation_matrix.transpose(-1, -2))

    # (F, C1, taps * C2) -> (F, C1, taps, C2) -> (F, taps, C2, C1)
    filter_matrix_conj = \
        stacked_filter_conj.view(F, C, taps, C).permute(0, 2, 3, 1)
    return filter_matrix_conj
Example #14
def wpe_one_iteration(Y: ComplexTensor,
                      power: torch.Tensor,
                      taps: int = 10,
                      delay: int = 3,
                      eps: float = 1e-10,
                      inverse_power: bool = True) -> ComplexTensor:
    """WPE for one iteration

    Args:
        Y: Complex valued STFT signal with shape (..., C, T)
        power: (..., T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        eps:
        inverse_power (bool):
    Returns:
        enhanced: (..., C, T)
    """
    assert Y.size()[:-2] == power.size()[:-1]
    batch_freq_size = Y.size()[:-2]
    Y = Y.view(-1, *Y.size()[-2:])
    power = power.view(-1, power.size()[-1])

    if inverse_power:
        inverse_power = 1 / torch.clamp(power, min=eps)
    else:
        inverse_power = power

    correlation_matrix, correlation_vector = \
        get_correlations(Y, inverse_power, taps, delay)
    filter_matrix_conj = get_filter_matrix_conj(correlation_matrix,
                                                correlation_vector)
    enhanced = perform_filter_operation_v2(Y, filter_matrix_conj, taps, delay)

    enhanced = enhanced.view(*batch_freq_size, *Y.size()[-2:])
    return enhanced
Example #15
def perform_filter_operation_v2(Y: ComplexTensor,
                                filter_matrix_conj: ComplexTensor,
                                taps, delay) -> ComplexTensor:
    """perform_filter_operation_v2

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    # Y_tilde: (taps, F, C, T)
    Y_tilde = FC.stack([FC.pad(Y[:, :, :T - delay - i], (delay + i, 0),
                               mode='constant', value=0)
                        for i in range(taps)],
                       dim=0)
    reverb_tail = FC.einsum('fpde,pfdt->fet', (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
Example #16
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])

        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        #       -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        return input_feats, feats_lens
Example #17
    def forward(self, xs: ComplexTensor, input_lengths: torch.LongTensor) \
            -> torch.Tensor:
        assert xs.size(0) == input_lengths.size(0), (xs.size(0),
                                                     input_lengths.size(0))

        # xs: (B, C, T, D)
        C = xs.size(1)
        if self.feat_type == 'amplitude':
            # xs: (B, C, T, F) -> (B, C, T, F)
            xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
        elif self.feat_type == 'power':
            # xs: (B, C, T, F) -> (B, C, T, F)
            xs = xs.real ** 2 + xs.imag ** 2
        elif self.feat_type == 'log_power':
            # xs: (B, C, T, F) -> (B, C, T, F)
            xs = torch.log(xs.real ** 2 + xs.imag ** 2)
        elif self.feat_type == 'concat':
            # xs: (B, C, T, F) -> (B, C, T, 2 * F)
            xs = torch.cat([xs.real, xs.imag], -1)
        else:
            raise NotImplementedError(f'Not implemented: {self.feat_type}')

        if self.model_type in ('blstm', 'lstm'):
            # xs: (B, C, T, F) -> xs: (B, C, T, D)
            xs = self.net(xs, input_lengths)

        elif self.model_type == 'cnn':
            if self.channel_independent:
                # xs: (B, C, T, F) -> xs: (B * C, F, T)
                xs = xs.view(-1, *xs.size()[2:]).transpose(1, 2)
                # xs: (B * C, F, T) -> xs: (B * C, D, T)
                xs = self.net(xs)
                # xs: (B * C, D, T) -> (B, C, T, D)
                xs = xs.transpose(1, 2).contiguous().view(
                    -1, C, xs.size(2), xs.size(1))
            else:
                # xs: (B, C, T, F) -> xs: (B, C, T, F)
                xs = self.net(xs)
        else:
            raise NotImplementedError(f'Not implemented: {self.model_type}')

        # xs: (B, C, T, D) -> out:(B, C, T, F)
        out = self.linear(xs)
        # Zero padding
        out = torch.sigmoid(out)
        out = out.masked_fill(make_pad_mask(input_lengths, out, length_dim=2), 0)

        return out
Example #18
def perform_filter_operation(Y: ComplexTensor,
                             filter_matrix_conj: ComplexTensor, taps, delay) \
        -> ComplexTensor:
    """perform_filter_operation

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    reverb_tail = ComplexTensor(torch.zeros_like(Y.real),
                                torch.zeros_like(Y.real))
    for tau_minus_delay in range(taps):
        new = FC.einsum('fde,fdt->fet',
                        (filter_matrix_conj[:, tau_minus_delay, :, :],
                         Y[:, :, :T - delay - tau_minus_delay]))
        new = FC.pad(new, (delay + tau_minus_delay, 0),
                     mode='constant', value=0)
        reverb_tail = reverb_tail + new

    return Y - reverb_tail
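This per-tap loop builds the same reverberation tail as the stacked-einsum variant of Example #15, so the two should agree numerically. A quick consistency sketch, assuming both functions and the torch_complex imports above:

import torch
from torch_complex.tensor import ComplexTensor

Y = ComplexTensor(torch.randn(5, 2, 30), torch.randn(5, 2, 30))
G = ComplexTensor(torch.randn(5, 4, 2, 2), torch.randn(5, 4, 2, 2))
out_loop = perform_filter_operation(Y, G, taps=4, delay=3)
out_v2 = perform_filter_operation_v2(Y, G, taps=4, delay=3)
assert float((out_loop - out_v2).abs().max()) < 1e-4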
Example #19
def tik_reg(mat: ComplexTensor,
            reg: float = 1e-8,
            eps: float = 1e-8) -> ComplexTensor:
    """Perform Tikhonov regularization (only modifying real part).

    Args:
        mat (ComplexTensor): input matrix (..., C, C)
        reg (float): regularization factor
        eps (float):
    Returns:
        ret (ComplexTensor): regularized matrix (..., C, C)
    """
    # Add eps
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    shape = [1 for _ in range(mat.dim() - 2)] + [C, C]
    eye = eye.view(*shape).repeat(*mat.shape[:-2], 1, 1)
    with torch.no_grad():
        epsilon = FC.trace(mat).real[..., None, None] * reg
        # in case that correlation_matrix is all-zero
        epsilon = epsilon + eps
    mat = mat + epsilon * eye
    return mat
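A hedged usage sketch: regularizing a covariance that is singular by construction (rank one) so that it can be inverted; torch_complex assumed, and reg enlarged here purely for illustration.

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

v = ComplexTensor(torch.randn(5, 4, 1), torch.randn(5, 4, 1))
mat = FC.matmul(v, v.conj().transpose(-1, -2))  # rank-1, hence singular
mat_reg = tik_reg(mat, reg=1e-3)
inv = mat_reg.inverse()  # now numerically well-posed
assert inv.size() == mat.size()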
Example #20
def trace(a: ComplexTensor) -> ComplexTensor:
    E = torch.eye(a.real.size(-1), dtype=torch.uint8).expand(*a.size())
    return a[E].view(*a.size()[:-1]).sum(-1)
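Note that indexing with a uint8 mask is deprecated in recent PyTorch; the version-guarded trace() in Example #3 avoids the warning. On torch >= 1.2 a bool mask can be used directly (trace_bool is a hypothetical name for the same logic):

import torch

def trace_bool(a):
    # same logic as above, with a bool mask instead of uint8
    E = torch.eye(a.real.size(-1), dtype=torch.bool).expand(*a.size())
    return a[E].view(*a.size()[:-1]).sum(-1)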
Example #21
    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor = None,
        return_wpe: bool = True,
    ) -> Tuple[Optional[ComplexTensor], torch.Tensor]:
        if ilens is None:
            ilens = torch.full((data.size(0),), data.size(2),
                               dtype=torch.long, device=data.device)

        r = -self.rcontext if self.rcontext != 0 else None
        enhanced = data[:, :, self.lcontext:r, :]

        if self.lcontext != 0 or self.rcontext != 0:
            assert all(ilens[0] == i for i in ilens)

            # Create context window (a.k.a. splicing)
            if self.model_type in ('blstm', 'lstm'):
                width = data.size(2) - self.lcontext - self.rcontext
                # data: (B, C, l + w + r, F)
                indices = [i + j for i in range(width)
                           for j in range(1 + self.lcontext + self.rcontext)]
                _y = data[:, :, indices]
                # data: (B, C, w, (1 + l + r) * F)
                data = _y.view(
                    data.size(0), data.size(1),
                    width, (1 + self.lcontext + self.rcontext) * data.size(3))
                ilens = torch.full((data.size(0),), width,
                                   dtype=torch.long, device=data.device)
                del _y

        for i in range(self.iterations):
            # Calculate power: (B, C, T, F)
            power = enhanced.real ** 2 + enhanced.imag ** 2
            if i == 0 and self.use_dnn:
                # mask: (B, C, T, F)
                mask = self.estimator(data, ilens)
                if mask.size(2) != power.size(2):
                    assert mask.size(2) == (power.size(2) + self.rcontext + self.lcontext)
                    r = -self.rcontext if self.rcontext != 0 else None
                    mask = mask[:, :, self.lcontext:r, :]

                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-2, keepdim=True)
                if self.out_type == 'mask':
                    power = power * mask
                else:
                    power = mask

                    if self.out_type == 'amplitude':
                        power = power ** 2
                    elif self.out_type == 'log_power':
                        power = power.exp()
                    elif self.out_type == 'power':
                        pass
                    else:
                        raise NotImplementedError(self.out_type)

            if not return_wpe:
                return None, power

            # power: (B, C, T, F) -> _power: (B, F, T)
            _power = power.mean(dim=1).transpose(-1, -2).contiguous()

            # data: (B, C, T, F) -> _data: (B, F, C, T)
            _data = data.permute(0, 3, 1, 2).contiguous()
            # _enhanced: (B, F, C, T)
            _enhanced_real = []
            _enhanced_imag = []
            for d, p, l in zip(_data, _power, ilens):
                # e: (F, C, T) -> (T, C, F)
                e = wpe_one_iteration(
                    d[..., :l], p[..., :l],
                    taps=self.taps, delay=self.delay,
                    inverse_power=self.inverse_power).transpose(0, 2)
                _enhanced_real.append(e.real)
                _enhanced_imag.append(e.imag)
            # _enhanced: B x (T, C, F) -> (B, T, C, F) -> (B, F, C, T)
            _enhanced_real = pad_sequence(_enhanced_real,
                                          batch_first=True).transpose(1, 3)
            _enhanced_imag = pad_sequence(_enhanced_imag,
                                          batch_first=True).transpose(1, 3)
            _enhanced = ComplexTensor(_enhanced_real, _enhanced_imag)

            # enhanced: (B, F, C, T) -> (B, C, T, F)
            enhanced = _enhanced.permute(0, 2, 3, 1)

        # enhanced: (B, C, T, F), power: (B, C, T, F)
        return enhanced, power
Example #22
    # (..., C, T) * (..., C, T) -> (..., C, T)
    power = power * mask_speech
    # Averaging along the channel axis: (B, F, C, T) -> (B, F, T)
    power = power.mean(dim=-2)
    # (B, F, T) --> (B * F, T)
    power = power.view(-1, power.shape[-1])
    inverse_power = 1 / torch.clamp(power, min=eps)

    B, Fdim, C, T = Z.shape

    # covariance matrix: (B, F, (btaps+1) * C, (btaps+1) * C)
    covariance_matrix = get_covariances(
        Z, inverse_power, bdelay, btaps, get_vector=False
    )

    # speech signal PSD: (B, F, C, C)
    psd_speech = beamformer.get_power_spectral_density_matrix(
        Z, mask_speech, btaps, normalization=True
    )

    # reference vector: (B, C)
    ref_channel = 0
    u = torch.zeros(*(Z.size()[:-3] + (Z.size(-2),)), device=Z.device)
    u[..., ref_channel].fill_(1)

    # (B, F, (btaps + 1) * C)
    WPD_filter = get_WPD_filter_v2(psd_speech, covariance_matrix, u)

    # (B, F, T)
    enhanced = perform_WPD_filtering(WPD_filter, Z, bdelay, btaps)
Example #23
def online_wpe_step(input_buffer: ComplexTensor,
                    power: torch.Tensor,
                    inv_cov: ComplexTensor = None,
                    filter_taps: ComplexTensor = None,
                    alpha: float = 0.99,
                    taps: int = 10,
                    delay: int = 3):
    """One step of online dereverberation.

    Args:
        input_buffer: (F, C, taps + delay + 1)
        power: Estimate for the current PSD (F, T)
        inv_cov: Current estimate of R^-1
        filter_taps: Current estimate of filter taps (F, taps * C, C)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        Dereverberated frame of shape (F, C)
        Updated estimate of R^-1
        Updated estimate of the filter taps


    >>> frame_length = 512
    >>> frame_shift = 128
    >>> taps = 6
    >>> delay = 3
    >>> alpha = 0.999
    >>> frequency_bins = frame_length // 2 + 1
    >>> Q = None
    >>> G = None
    >>> unreverbed, Q, G = online_wpe_step(stft, get_power_online(stft), Q, G,
    ...                                    alpha=alpha, taps=taps, delay=delay)

    """
    assert input_buffer.size(-1) == taps + delay + 1, input_buffer.size()
    C = input_buffer.size(-2)

    if inv_cov is None:
        inv_cov = ComplexTensor(
            torch.eye(C * taps, dtype=input_buffer.dtype).expand(
                *input_buffer.size()[:-2], C * taps, C * taps))
    if filter_taps is None:
        filter_taps = ComplexTensor(
            torch.zeros(*input_buffer.size()[:-2],
                        C * taps,
                        C,
                        dtype=input_buffer.dtype))

    window = FC.reverse(input_buffer[..., :-delay - 1], dim=-1)
    # (..., C, T) -> (..., C * T)
    window = window.view(*input_buffer.size()[:-2], -1)
    pred = input_buffer[..., -1] - FC.einsum('...id,...i->...d',
                                             (filter_taps.conj(), window))

    numerator = FC.einsum('...ij,...j->...i', (inv_cov, window))
    denominator = \
        FC.einsum('...i,...i->...', (window.conj(), numerator)) + alpha * power
    kalman_gain = numerator / denominator[..., None]

    inv_cov_k = inv_cov - FC.einsum('...j,...jm,...i->...im',
                                    (window.conj(), inv_cov, kalman_gain))
    inv_cov_k /= alpha

    filter_taps_k = \
        filter_taps + FC.einsum('...i,...m->...im', (kalman_gain, pred.conj()))
    return pred, inv_cov_k, filter_taps_k
Example #24
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer, computed as:

        h = (Rf^-1 @ vbar) / (vbar^H @ Rf^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float):
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
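A shape-only sketch of the RTF zero-padding step (hypothetical sizes; torch_complex assumed), showing how the (B, F, C, 1) RTF is extended to match the stacked observation covariance:

import torch
import torch_complex.functional as FC
from torch_complex.tensor import ComplexTensor

B, F_, C, K = 2, 5, 4, 3
rtf = ComplexTensor(torch.randn(B, F_, C, 1), torch.randn(B, F_, C, 1))
# pad the channel axis from C up to (K + 1) * C:
rtf_bar = FC.pad(rtf, (0, 0, 0, (K + 1) * C - C), "constant", 0)
assert rtf_bar.shape == (B, F_, (K + 1) * C, 1)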