コード例 #1
0
ファイル: functional.py プロジェクト: tmasquelier/SpykeTorch
def threshold_(potentials, threshold=None):
    r"""In-place counterpart of :func:`~threshold`.

    Args:
        potentials (Tensor): Potentials that are modified in place.
        threshold (float): Entries less than or equal to this value are set
            to zero. If ``None``, every entry except the last one along the
            first dimension (the final time step) is zeroed. Default: None

    Returns:
        None. ``potentials`` itself holds the result.
    """
    if threshold is not None:
        fn.threshold_(potentials, threshold, 0)
        return
    # No threshold: only the final time step keeps its potentials.
    potentials[:-1] = 0
コード例 #2
0
ファイル: functional.py プロジェクト: tmasquelier/SpykeTorch
def fire_(potentials, threshold=None):
    r"""In-place counterpart of :func:`~fire`.

    Thresholds ``potentials`` in place and then replaces every entry by its
    sign, turning surviving potentials into spikes.

    Args:
        potentials (Tensor): Potentials that are overwritten with spikes.
        threshold (float): Entries less than or equal to this value are set
            to zero before taking the sign. If ``None``, only the final time
            step (last entry along the first dimension) survives. Default: None

    Returns:
        None. ``potentials`` itself holds the spike tensor.

    NOTE(review): ``sign_`` maps negative survivors to -1, so with
    ``threshold=None`` a negative potential in the last step yields -1.
    """
    keep_last_step_only = threshold is None
    if keep_last_step_only:
        potentials[:-1] = 0
    else:
        fn.threshold_(potentials, threshold, 0)
    # Collapse the remaining potentials to -1 / 0 / 1.
    potentials.sign_()
コード例 #3
0
def threshold(potentials, threshold=None):
    r"""Zero out every potential that does not exceed a threshold.

    All values less than or equal to :attr:`threshold` become zero. When
    :attr:`threshold` is :attr:`None`, only the potentials of the final time
    step (last entry along the first dimension) are kept.

    The input is never modified: the work is done on a detached copy.

    Args:
            potentials (Tensor): The tensor of input potentials.
            threshold (float): The threshold value. Default: None

    Returns:
            Tensor: Thresholded potentials (detached from the autograd graph).
    """
    result = potentials.clone().detach()
    if threshold is not None:
        fn.threshold_(result, threshold, 0)
    else:
        result[:-1] = 0
    return result
コード例 #4
0
def fire(potentials, threshold=None, return_thresholded_potentials=False):
    r"""Compute a spike-wave tensor from a tensor of potentials.

    Potentials less than or equal to :attr:`threshold` are zeroed and the
    sign of what remains is taken as the spike-wave. If :attr:`threshold` is
    :attr:`None`, only the final time step (last entry along the first
    dimension) can spike. The input tensor is not modified.

    NOTE(review): the sign is taken as-is, so negative surviving potentials
    (possible with ``threshold=None``) produce -1 rather than no spike.

    Args:
            potentials (Tensor): The tensor of input potentials.
            threshold (float): Firing threshold. Default: None
            return_thresholded_potentials (boolean): If True, the thresholded
            potentials are returned together with the spike-wave. Default: False

    Returns:
            Tensor: Spike-wave tensor, or a ``(spikes, thresholded)`` tuple when
            :attr:`return_thresholded_potentials` is True.
    """
    thresholded = potentials.clone().detach()
    if threshold is not None:
        fn.threshold_(thresholded, threshold, 0)
    else:
        thresholded[:-1] = 0
    spikes = thresholded.sign()
    return (spikes, thresholded) if return_thresholded_potentials else spikes
コード例 #5
0
 def test_threshold_(self):
     """Check the in-place ``F.threshold_`` on a CUDA tensor.

     ``F.threshold_(x, 6, 6)`` keeps elements greater than 6 and replaces
     every other element with 6, writing the result into ``x`` itself.
     """
     inp = torch.randn(1, 8, 32, 32, device='cuda', dtype=self.dtype)
     output = F.threshold_(inp, 6, 6)
     # Every element is either kept (> 6) or replaced by the value 6.
     assert bool((output >= 6).all())
     # The operation is in place: the result shares storage with the input.
     assert output.data_ptr() == inp.data_ptr()
コード例 #6
0
def ADMM(spec,
         maxiter=1000,
         tol=1e-6,
         rho=0.1,
         verbose=1,
         evaiter=10,
         metric='sc',
         **stft_kwargs):
    r"""
    Reconstruct spectrogram phase using `Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`_ .

    .. _`Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`:
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8552369

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude) or :math:`(N \times T \times 2)`
            (complex input, real/imaginary parts in the last dimension). If a magnitude spectrogram is given,
            the phase is first initialized using :func:`torch_specinv.methods.phase_init`; otherwise the
            iteration starts from the complex input and its magnitude is used as the target.
        maxiter (int): maximum number of iterations before timing out. Default: ``1000``
        tol (float): tolerance of the stopping condition based on the L2 loss. Default: ``1e-6``
        rho (float): non-negative speedup parameter. A small value is preferable when the input spectrogram
            is noisy (imperfect); setting it to 1 behaves similarly to ``griffin_lim``. Default: ``0.1``
        verbose (bool): whether to show a progress bar. Default: ``1`` (truthy)
        evaiter (int): step size for evaluation. After each step, the function defined in ``metric`` is
            evaluated. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Any other value falls back to ``'sc'``. Default: ``'sc'``
        **stft_kwargs: other arguments that are passed to :func:`torch.stft`.

    Returns:
        A 1d tensor converted from the given spectrogram
    """
    n_fft, processed_args = _args_helper(spec, **stft_kwargs)

    istft = partial(_istft, n_fft=n_fft, **processed_args)

    if len(spec.shape) == 2:
        # Magnitude input: build an initial complex estimate.
        X = phase_init(spec, **stft_kwargs)
    else:
        # Complex input: start from it directly. Every projection below uses
        # ``spec`` as a target magnitude, so replace it by the input's magnitude.
        X = spec.clone()
        spec = X.pow(2).sum(2).sqrt()

    x = istft(X)
    Z = X.clone()
    Y = X.clone()
    U = torch.zeros_like(X)

    criterion = nn.MSELoss()
    init_loss = None
    previous_loss = None
    bar_dict = {}
    if metric == 'snr':
        metric_func = SNR
        bar_dict['SNR'] = 0
        metric = metric.upper()
    elif metric == 'ser':
        metric_func = SER
        bar_dict['SER'] = 0
        metric = metric.upper()
    else:
        metric_func = spectral_convergence
        bar_dict['spectral_convergence'] = 0
        metric = 'spectral_convergence'

    with tqdm(total=maxiter, disable=not verbose) as pbar:
        for i in range(maxiter):
            # Pc2: rescale the estimate to the target magnitude.
            X[:] = Z - U
            mag = X.pow(2).sum(2).sqrt()
            # threshold_ floors the magnitude at 1e-7 to avoid division by zero.
            X *= (spec / F.threshold_(mag, 1e-7, 1e-7)).unsqueeze(-1)

            Y[:] = X + U
            # Pc1: inverse STFT followed by STFT.
            x[:] = istft(Y)
            reconstructed = torch.stft(x, n_fft, **stft_kwargs)

            Z[:] = (rho * Y + reconstructed) / (1 + rho)
            U += X - Z

            if i % evaiter == evaiter - 1:
                mag = reconstructed.pow(2).sum(2).sqrt()
                bar_dict[metric] = metric_func(mag, spec).item()
                l2_loss = criterion(mag, spec).item()
                pbar.set_postfix(**bar_dict, loss=l2_loss)
                pbar.update(evaiter)

                # Stop when the relative improvement over the last evaluation
                # window drops below tol * evaiter while the loss still decreases.
                if not init_loss:
                    init_loss = l2_loss
                elif (
                        previous_loss - l2_loss
                ) / init_loss < tol * evaiter and previous_loss > l2_loss:
                    break
                previous_loss = l2_loss

    return istft(X)
コード例 #7
0
def RTISI_LA(spec,
             look_ahead=-1,
             asymmetric_window=False,
             maxiter=25,
             alpha=0.99,
             verbose=1,
             **stft_kwargs):
    r"""
    Reconstruct spectrogram phase using `Real-Time Iterative Spectrogram Inversion with Look Ahead`_ (RTISI-LA).

    .. _`Real-Time Iterative Spectrogram Inversion with Look Ahead`:
        https://lonce.org/home/Publications/publications/2007_RealtimeSignalReconstruction.pdf


    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude).
        look_ahead (int): how many future frames are considered. ``-1`` sets it to ``(win_length - 1) // hop_length``;
            ``0`` disables the look-ahead strategy and falls back to the original RTISI algorithm. Default: ``-1``
        asymmetric_window (bool): whether to apply an asymmetric window on the newest frame instead of
            re-analysing with :func:`torch.stft`. Default: ``False``
        maxiter (int): number of iterations for each step. Default: ``25``
        alpha (float): speedup (momentum) parameter as used in Fast Griffin-Lim; setting it to zero disables it.
            Default: ``0.99``
        verbose (bool): whether to show a progress bar. Default: ``1`` (truthy)
        **stft_kwargs: other arguments that are passed to :func:`torch.stft`.

    Returns:
        A 1d tensor converted from the given spectrogram

    """
    n_fft, proccessed_args = _args_helper(spec, **stft_kwargs)
    # Frames are re-analysed inside the loop without centre padding.
    copyed_kwargs = stft_kwargs.copy()
    copyed_kwargs['center'] = False

    win_length = proccessed_args['win_length']
    hop_length = proccessed_args['hop_length']
    synth_coeff = proccessed_args['synth_coeff']
    offset = proccessed_args['offset']
    onesided = proccessed_args['onesided']
    normalized = proccessed_args['normalized']
    ola_weight = proccessed_args['ola_weight']
    window = proccessed_args['window']
    if window is None:
        window = torch.hann_window(win_length).to(spec.device)

    # Number of earlier frames that overlap the current one:
    # frames d hops apart overlap iff d * hop_length < win_length.
    num_keep = (win_length - 1) // hop_length
    if look_ahead < 0:
        look_ahead = num_keep

    # Asymmetric windows for the newest frame (used when asymmetric_window is
    # set): sums of hop-shifted, flipped copies of `window` (num_keep copies for
    # the first inner iteration, num_keep + 1 afterwards), rescaled by
    # hop_length / sum(asym * window) / synth_coeff.
    asym_window1 = spec.new_zeros(win_length)
    for i in range(num_keep):
        asym_window1[(i + 1) * hop_length:] += window.flip(0)[:-(i + 1) *
                                                              hop_length:]
    asym_window1 *= hop_length / (asym_window1 * window).sum() / synth_coeff

    asym_window2 = spec.new_zeros(win_length)
    for i in range(num_keep + 1):
        asym_window2[i *
                     hop_length:] += window.flip(0)[:-i *
                                                    hop_length if i else None]
    asym_window2 *= hop_length / (asym_window2 * window).sum() / synth_coeff

    steps = spec.shape[1]
    # Time-domain frame buffer, one n_fft-long row per frame. Row num_keep + k
    # holds frame k of the padded `spec` below; the first num_keep rows are
    # leading zero frames.
    xt = spec.new_zeros(steps + num_keep + 2 * look_ahead, n_fft)
    # View of each frame's window support; writes here land in `xt`.
    xt_winview = xt[:, offset:offset + win_length]
    # Zero-pad the magnitudes along time (last dim) by look_ahead frames per side.
    spec = F.pad(spec, [look_ahead, look_ahead])

    # NOTE(review): torch.rfft / torch.irfft are the legacy FFT API, removed in
    # newer PyTorch releases in favour of the torch.fft module; this function
    # therefore requires an older PyTorch.
    def irfft(x):
        return torch.irfft(x,
                           1,
                           normalized=normalized,
                           onesided=onesided,
                           signal_sizes=[n_fft] if onesided else None)

    def rfft(x):
        return torch.rfft(x, 1, normalized=normalized, onesided=onesided)

    # initialize first frame with zero phase
    first_frame = spec[:, look_ahead]
    xt_winview[num_keep + look_ahead] = irfft(
        torch.stack((first_frame, torch.zeros_like(first_frame)),
                    -1))[offset:offset + win_length]

    # Step i commits one frame. Each inner iteration overlap-adds the frames in
    # the current buffer window, re-analyses the look_ahead + 1 newest frames,
    # rescales them to the target magnitudes and writes them back.
    with tqdm(total=steps + look_ahead, disable=not verbose) as pbar:
        for i in range(steps + look_ahead):
            for j in range(maxiter):
                x = _ola(xt_winview[i:i + num_keep + look_ahead + 1].t(),
                         window, hop_length, synth_coeff, ola_weight)
                if asymmetric_window:
                    xt_winview[i + num_keep:i + num_keep + look_ahead + 1] = \
                        x.unfold(0, win_length, hop_length)[num_keep:]

                    xt_winview[i + num_keep:i + num_keep +
                               look_ahead] *= window
                    if j:
                        xt_winview[i + num_keep + look_ahead] *= asym_window2
                    else:
                        xt_winview[i + num_keep + look_ahead] *= asym_window1

                    new_spec = rfft(xt[i + num_keep:i + num_keep + look_ahead +
                                       1]).transpose(0, 1)
                else:
                    new_spec = torch.stft(F.pad(
                        x[num_keep * hop_length - offset:], [0, offset]),
                                          n_fft=n_fft,
                                          **copyed_kwargs)

                # Momentum (alpha) step. On the first inner iteration of a new
                # step (j == 0, i > 0) the previous estimate is shifted by one
                # frame (pre_spec[:, 1:]) so it lines up with the current frames.
                if j:
                    new_spec += alpha * (new_spec - pre_spec)
                    pre_spec.copy_(new_spec)
                elif i:
                    new_spec[:, :-1] += alpha * (new_spec[:, :-1] -
                                                 pre_spec[:, 1:])
                    pre_spec.copy_(new_spec)
                else:
                    pre_spec = new_spec.clone()

                # Floor magnitudes at 1e-7 before dividing, then rescale to the
                # target magnitudes while keeping the estimated phase.
                mag = F.threshold_(new_spec.pow(2).sum(2).sqrt(), 1e-7, 1e-7)
                new_spec *= (spec[:, i:i + look_ahead + 1] / mag).unsqueeze(-1)

                xt_winview[i + num_keep:i + num_keep + look_ahead + 1] = irfft(
                    new_spec.transpose(0, 1))[:, offset:offset + win_length]

            pbar.update()

    # Overlap-add the committed frames, dropping the leading num_keep +
    # look_ahead rows and the trailing look_ahead padding rows.
    x = _ola(
        xt_winview[num_keep +
                   look_ahead:-look_ahead if look_ahead else None].t(), window,
        hop_length, synth_coeff, ola_weight)
    # NOTE(review): `-win_length // 2` parses as `(-win_length) // 2`, so for an
    # odd win_length one more sample is trimmed at the end than at the start —
    # confirm this is intended.
    if proccessed_args['center']:
        x = x[win_length // 2:-win_length // 2]
    else:
        x = F.pad(x, [offset, offset])

    return x
コード例 #8
0
def griffin_lim(spec,
                maxiter: int = 200,
                tol: float = 1e-6,
                alpha: float = 0.99,
                verbose: bool = True,
                evaiter: int = 10,
                metric='sc',
                **stft_kwargs):
    r"""Reconstruct spectrogram phase using the well-known `Griffin-Lim`_ algorithm and its variation, `Fast Griffin-Lim`_.


    .. _`Griffin-Lim`: https://pdfs.semanticscholar.org/14bc/876fae55faf5669beb01667a4f3bd324a4f1.pdf
    .. _`Fast Griffin-Lim`: https://perraudin.info/publications/perraudin-note-002.pdf

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude) or :math:`(N \times T \times 2)`
            (complex input, real/imaginary parts in the last dimension). If a magnitude spectrogram is given,
            the phase is first initialized using :func:`torch_specinv.methods.phase_init`; otherwise the
            iteration starts from the complex input and its magnitude is used as the target.
        maxiter (int): maximum number of iterations before timing out. Default: ``200``
        tol (float): tolerance of the stopping condition based on the L2 loss. Default: ``1e-6``
        alpha (float): speedup parameter used in `Fast Griffin-Lim`_; setting it to zero disables it. Default: ``0.99``
        verbose (bool): whether to show a progress bar. Default: :obj:`True`
        evaiter (int): step size for evaluation. After each step, the function defined in `metric` is
            evaluated. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Any other value falls back to ``'sc'``. Default: ``'sc'``
        **stft_kwargs: other arguments that are passed to :func:`torch.stft`

    Returns:
        A 1d tensor converted from the given spectrogram

    """
    n_fft, processed_args = _args_helper(spec, **stft_kwargs)

    istft = partial(_istft, n_fft=n_fft, **processed_args)

    if len(spec.shape) == 2:
        # Magnitude input: build an initial complex estimate.
        new_spec = phase_init(spec, **stft_kwargs)
    else:
        # Complex input: start from it directly. The projection below uses
        # ``spec`` as a target magnitude, so replace it by the input's magnitude.
        new_spec = spec.clone()
        spec = new_spec.pow(2).sum(2).sqrt()
    pre_spec = new_spec.clone()
    x = istft(new_spec)

    criterion = nn.MSELoss()
    init_loss = None
    previous_loss = None
    bar_dict = {}
    if metric == 'snr':
        metric_func = SNR
        bar_dict['SNR'] = 0
        metric = metric.upper()
    elif metric == 'ser':
        metric_func = SER
        bar_dict['SER'] = 0
        metric = metric.upper()
    else:
        metric_func = spectral_convergence
        bar_dict['spectral_convergence'] = 0
        metric = 'spectral_convergence'

    with tqdm(total=maxiter, disable=not verbose) as pbar:
        for i in range(maxiter):
            new_spec[:] = torch.stft(x, n_fft, **stft_kwargs)
            # Fast Griffin-Lim momentum: extrapolate along the last update.
            new_spec += alpha * (new_spec - pre_spec)
            pre_spec.copy_(new_spec)

            # Magnitude of the extrapolated estimate; threshold_ floors it at
            # 1e-7 to avoid division by zero, then rescale to the target.
            mag = new_spec.pow(2).sum(2).sqrt()
            new_spec *= (spec / F.threshold_(mag, 1e-7, 1e-7)).unsqueeze(-1)
            x[:] = istft(new_spec)

            if i % evaiter == evaiter - 1:
                bar_dict[metric] = metric_func(mag, spec).item()
                l2_loss = criterion(mag, spec).item()
                pbar.set_postfix(**bar_dict, loss=l2_loss)
                pbar.update(evaiter)

                # Stop when the relative improvement over the last evaluation
                # window drops below tol * evaiter while the loss still decreases.
                if not init_loss:
                    init_loss = l2_loss
                elif (
                        previous_loss - l2_loss
                ) / init_loss < tol * evaiter and previous_loss > l2_loss:
                    break
                previous_loss = l2_loss

    return x