Example #1
    def __iter__(self):
        _ks = []
        _xs = []
        _ts = []
        _ilens = []

        for ks, xs, ts, ilens in self.dataloader:
            assert len(ks) == 1, \
                f'batch-size of dataloader is not 1: {len(ks)} != 1'
            # Check shape:
            assert isinstance(ks[0], str), type(ks[0])
            # Expected: x: (C, T, F), t: (C, T, F)
            assert xs[0].dim() == 3, xs[0].shape
            assert xs[0].shape == ts[0].shape, (xs[0].shape, ts[0].shape)

            offset = 0
            ilen = ilens[0]
            # Slide a fixed-width window over the utterance;
            # any tail shorter than `width` is dropped.
            while offset + self.width <= ilen:

                _k = ks[0]
                # _x: time slice with left/right context attached
                _x = xs[0, :,
                        max(offset - self.lcontext, 0):offset + self.width + self.rcontext, :]
                if _x.shape[1] < self.width + self.lcontext + self.rcontext:
                    # Zero-pad whatever context is missing at the utterance edges.
                    lp = max(0, self.lcontext - offset)
                    rp = max(offset + self.width + self.rcontext - xs.size(2), 0)
                    # pad pairs run (left, right) from the last dim:
                    # (F_left, F_right, T_left, T_right, C_left, C_right)
                    _x = FC.pad(_x, (0, 0, lp, rp, 0, 0), mode='constant')

                # _t: (C, width, F)
                _t = ts[0, :, offset:offset + self.width, :]
                _l = self.width + self.lcontext + self.rcontext

                _ks.append(_k)
                _xs.append(_x)
                _ts.append(_t)
                _ilens.append(_l)
                offset += self.width

                if len(_ks) == self.batch_size:
                    _ks = tuple(_ks)
                    # stacked _xs: (B, C, width + lcontext + rcontext, F)
                    yield _ks, FC.stack(_xs), \
                        FC.stack(_ts), torch.tensor(_ilens, device=_xs[0].device)

                    _ks = []
                    _xs = []
                    _ts = []
                    _ilens = []
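
The padding above hinges on `pad`'s argument order: (left, right) pairs starting from the last dimension. A minimal torch-only sketch of that ordering (the width/context values here are made up, not from the snippet):

import torch
import torch.nn.functional as F

# Toy (C, T, F) chunk that is missing 2 frames of left context:
# width=6, lcontext=2, rcontext=2, so the full window should have T=10.
_x = torch.ones(1, 8, 3)
lp, rp = 2, 0
# (F_left, F_right, T_left, T_right, C_left, C_right)
_x = F.pad(_x, (0, 0, lp, rp, 0, 0), mode='constant')
assert _x.shape == (1, 10, 3)
assert (_x[:, :lp] == 0).all()  # zeros landed at the start of the time axis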
Example #2
def get_adjacent(spec: ComplexTensor, filter_length: int = 5) -> ComplexTensor:
    """Zero-pad and unfold stft, i.e.,

    add zeros to the beginning so that, using the multi-frame signal model,
    there will be as many output frames as input frames.

    Args:
        spec (ComplexTensor): input spectrum (B, F, T)
        filter_length (int): length for frame extension
    Returns:
        ret (ComplexTensor): output spectrum (B, F, T, filter_length)
    """  # noqa: D400
    return (FC.pad(spec,
                   pad=[filter_length - 1, 0]).unfold(dim=-1,
                                                      size=filter_length,
                                                      step=1).contiguous())
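
A quick shape check (a sketch, assuming `torch_complex` is installed): padding the time axis on the left by `filter_length - 1` before unfolding keeps the frame count equal to T.

import torch
from torch_complex.tensor import ComplexTensor

spec = ComplexTensor(torch.randn(2, 257, 100), torch.randn(2, 257, 100))
frames = get_adjacent(spec, filter_length=5)
assert frames.shape == (2, 257, 100, 5)  # (B, F, T, filter_length)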
Example #3
def perform_filter_operation_v2(Y: ComplexTensor,
                                filter_matrix_conj: ComplexTensor,
                                taps, delay) -> ComplexTensor:
    """perform_filter_operation_v2

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    # Y_tilde: (taps, F, C, T)
    Y_tilde = FC.stack([FC.pad(Y[:, :, :T - delay - i], (delay + i, 0),
                               mode='constant', value=0)
                        for i in range(taps)],
                       dim=0)
    reverb_tail = FC.einsum('fpde,pfdt->fet', (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail
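
The stack builds `taps` progressively delayed copies of Y before a single einsum contracts them with the filter. The same construction for a plain real tensor (a toy sketch, torch only):

import torch

F_, C, T, taps, delay = 2, 3, 6, 2, 1
Y = torch.randn(F_, C, T)
Y_tilde = torch.stack(
    [torch.nn.functional.pad(Y[:, :, :T - delay - i], (delay + i, 0))
     for i in range(taps)],
    dim=0)
# Tap i is Y shifted right by (delay + i) frames, zero-filled at the start.
assert Y_tilde.shape == (taps, F_, C, T)
assert (Y_tilde[0, :, :, delay:] == Y[:, :, :T - delay]).all()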
Example #4
def perform_filter_operation(Y: ComplexTensor,
                             filter_matrix_conj: ComplexTensor, taps, delay) \
        -> ComplexTensor:
    """perform_filter_operation

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter Matrix (F, taps, C, C)
    """
    T = Y.size(-1)
    reverb_tail = ComplexTensor(torch.zeros_like(Y.real),
                                torch.zeros_like(Y.real))
    for tau_minus_delay in range(taps):
        new = FC.einsum('fde,fdt->fet',
                        (filter_matrix_conj[:, tau_minus_delay, :, :],
                         Y[:, :, :T - delay - tau_minus_delay]))
        new = FC.pad(new, (delay + tau_minus_delay, 0),
                     mode='constant', value=0)
        reverb_tail = reverb_tail + new

    return Y - reverb_tail
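
Both variants should agree up to numerical precision; v2 just replaces the Python loop with one stacked einsum. A minimal equivalence check (a sketch, assuming `torch_complex` is installed and both functions above are in scope):

import torch
from torch_complex.tensor import ComplexTensor

F_, C, T, taps, delay = 3, 2, 16, 4, 2
Y = ComplexTensor(torch.randn(F_, C, T), torch.randn(F_, C, T))
G = ComplexTensor(torch.randn(F_, taps, C, C), torch.randn(F_, taps, C, C))

out_loop = perform_filter_operation(Y, G, taps, delay)
out_vec = perform_filter_operation_v2(Y, G, taps, delay)
assert torch.allclose(out_loop.real, out_vec.real, atol=1e-5)
assert torch.allclose(out_loop.imag, out_vec.imag, atol=1e-5)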
Example #5
def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    bdelay: int,
    do_padding: bool = False,
    pad_value: int = 0,
    indices: List = None,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expand `signal` into several frames, with each frame of length `frame_length`.

    Args:
        signal : (..., T)
        frame_length:   length of each segment
        frame_step:     step for selecting frames
        bdelay:         delay for WPD
        do_padding:     whether or not to pad the input signal at the beginning
                          of the time dimension
        pad_value:      value to fill in the padding
        indices:        pre-computed frame indices; passed internally when the
                          function recurses on the real and imaginary parts

    Returns:
        torch.Tensor or ComplexTensor:
            if do_padding: (..., T, frame_length)
            else:          (..., T - bdelay - frame_length + 2, frame_length)
    """
    if indices is None:
        frame_length2 = frame_length - 1
        # pad to the left at the last dimension of `signal` (time dimension)
        if do_padding:
            # (..., T) --> (..., T + bdelay + frame_length - 2)
            signal = FC.pad(
                signal, (bdelay + frame_length2 - 1, 0), "constant", pad_value
            )

        # indices:
        # [[ 0, 1, ..., frame_length2 - 1,              frame_length2 - 1 + bdelay ],
        #  [ 1, 2, ..., frame_length2,                  frame_length2 + bdelay     ],
        #  [ 2, 3, ..., frame_length2 + 1,              frame_length2 + 1 + bdelay ],
        #  ...
        #  [ T-bdelay-frame_length2, ..., T-1-bdelay,   T-1 ]]
        indices = [
            [*range(i, i + frame_length2), i + frame_length2 + bdelay - 1]
            for i in range(0, signal.shape[-1] - frame_length2 - bdelay + 1, frame_step)
        ]

    if isinstance(signal, ComplexTensor):
        real = signal_framing(
            signal.real,
            frame_length,
            frame_step,
            bdelay,
            do_padding,
            pad_value,
            indices,
        )
        imag = signal_framing(
            signal.imag,
            frame_length,
            frame_step,
            bdelay,
            do_padding,
            pad_value,
            indices,
        )
        return ComplexTensor(real, imag)
    else:
        # (..., T - bdelay - frame_length + 2, frame_length)
        signal = signal[..., indices]
        # signal[..., :-1] = -signal[..., :-1]
        return signal
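
A toy run of the real-tensor branch (the frame values follow the index pattern in the comment inside the function):

import torch

x = torch.arange(10.0)  # (T,) with T = 10
frames = signal_framing(x, frame_length=3, frame_step=1, bdelay=2)
# Each frame: frame_length - 1 adjacent samples plus one sample bdelay ahead.
assert frames.shape == (7, 3)  # T - bdelay - frame_length + 2 = 7
assert frames[0].tolist() == [0.0, 1.0, 3.0]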
Example #6
def get_WPD_filter_with_rtf(
    psd_observed_bar: ComplexTensor,
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    iterations: int = 3,
    reference_vector: Union[int, torch.Tensor, None] = None,
    normalize_ref_channel: Optional[int] = None,
    use_torch_solver: bool = True,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the WPD vector calculated with RTF.

        WPD is the Weighted Power minimization Distortionless response
        convolutional beamformer. As follows:

        h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)

    Reference:
        T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
        for Simultaneous Denoising and Dereverberation," in IEEE Signal
        Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
        10.1109/LSP.2019.2911179.
        https://ieeexplore.ieee.org/document/8691481

    Args:
        psd_observed_bar (ComplexTensor): stacked observation covariance matrix
        psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
        psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
        iterations (int): number of iterations in power method
        reference_vector (torch.Tensor or int): (..., C) or scalar
        normalize_ref_channel (int): reference channel for normalizing the RTF
        use_torch_solver (bool): Whether to use `solve` instead of `inverse`
        diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
        diag_eps (float): regularization coefficient for diagonal loading
        eps (float): small constant to avoid division by zero
    Returns:
        beamform_vector (ComplexTensor): (..., F, (K+1)*C)
    """
    C = psd_noise.size(-1)
    if diagonal_loading:
        psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)

    # (B, F, C, 1)
    rtf = get_rtf(
        psd_speech,
        psd_noise,
        reference_vector,
        iterations=iterations,
        use_torch_solver=use_torch_solver,
    )

    # (B, F, (K+1)*C, 1)
    rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
    # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
    if use_torch_solver:
        numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
    else:
        numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
    denominator = FC.einsum("...d,...d->...",
                            [rtf.squeeze(-1).conj(), numerator])
    if normalize_ref_channel is not None:
        scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
        beamforming_vector = numerator * scale / (
            denominator.real.unsqueeze(-1) + eps)
    else:
        beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
    return beamforming_vector
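
A shape-level usage sketch (assumptions: `torch_complex` plus the companion helpers `get_rtf` and `tik_reg` are in scope, as in ESPnet's beamformer module; K is the number of stacked taps, and the random inputs are placeholders, not valid covariance matrices):

import torch
from torch_complex.tensor import ComplexTensor

B, F_, C, K = 2, 4, 3, 5

def rand_ct(*shape):
    return ComplexTensor(torch.randn(*shape), torch.randn(*shape))

psd_bar = rand_ct(B, F_, (K + 1) * C, (K + 1) * C)  # stacked observation covariance
psd_speech = rand_ct(B, F_, C, C)
psd_noise = rand_ct(B, F_, C, C)

h = get_WPD_filter_with_rtf(psd_bar, psd_speech, psd_noise, reference_vector=0)
assert h.shape == (B, F_, (K + 1) * C)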