def threshold_(potentials, threshold=None):
    r"""The inplace version of :func:`~threshold`

    Args:
        potentials (Tensor): The tensor of potentials; it is modified in place.
        threshold (float): The threshold value. Default: None
    """
    if threshold is not None:
        # Entries that are not strictly greater than the threshold become zero.
        fn.threshold_(potentials, threshold, 0)
        return
    # No threshold given: keep only the last slice along the first dimension.
    potentials[:-1] = 0
def fire_(potentials, threshold=None):
    r"""The inplace version of :func:`~fire`

    Args:
        potentials (Tensor): The tensor of potentials; it is overwritten in
            place with the sign of the thresholded potentials.
        threshold (float): Firing threshold. Default: None
    """
    if threshold is not None:
        # Entries that are not strictly greater than the threshold become zero.
        fn.threshold_(potentials, threshold, 0)
    else:
        # No threshold given: keep only the last slice along the first dimension.
        potentials[:-1] = 0
    # Replace every surviving potential by its sign.
    potentials.sign_()
def threshold(potentials, threshold=None):
    r"""Applies a threshold on potentials by which all of the values lower or equal
    to the threshold become zero. If :attr:`threshold` is :attr:`None`, only the
    potentials corresponding to the final time step will survive.

    The input is left untouched; the result is a detached copy.

    Args:
        potentials (Tensor): The tensor of input potentials.
        threshold (float): The threshold value. Default: None

    Returns:
        Tensor: Thresholded potentials.
    """
    result = potentials.clone().detach()
    if threshold is not None:
        fn.threshold_(result, threshold, 0)
        return result
    # No threshold given: zero everything but the last slice along dim 0.
    result[:-1] = 0
    return result
def fire(potentials, threshold=None, return_thresholded_potentials=False):
    r"""Computes the spike-wave tensor from tensor of potentials. If :attr:`threshold`
    is :attr:`None`, all the neurons emit one spike (if the potential is greater
    than zero) in the last time step.

    The input is left untouched; the computation runs on a detached copy.

    Args:
        potentials (Tensor): The tensor of input potentials.
        threshold (float): Firing threshold. Default: None
        return_thresholded_potentials (boolean): If True, the tensor of thresholded
            potentials will be returned as well as the tensor of spike-wave.
            Default: False

    Returns:
        Tensor: Spike-wave tensor, or the pair ``(spike-wave, thresholded potentials)``
        when :attr:`return_thresholded_potentials` is True.
    """
    survivors = potentials.clone().detach()
    if threshold is not None:
        fn.threshold_(survivors, threshold, 0)
    else:
        # No threshold given: zero everything but the last slice along dim 0.
        survivors[:-1] = 0

    spikes = survivors.sign()
    if not return_thresholded_potentials:
        return spikes
    return spikes, survivors
def test_threshold_(self):
    """Smoke-test the in-place ``F.threshold_`` on a random CUDA tensor.

    The tensor is created with ``self.dtype``; the call only has to complete
    without raising, nothing is asserted about the result.
    """
    sample = torch.randn(1, 8, 32, 32, dtype=self.dtype, device='cuda')
    output = F.threshold_(sample, 6, 6)
def ADMM(spec, maxiter=1000, tol=1e-6, rho=0.1, verbose=1, evaiter=10, metric='sc', **stft_kwargs):
    r"""
    Reconstruct spectrogram phase using `Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`_ .

    .. _`Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`:
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8552369

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude) or :math:`(N \times T \times 2)`
            (complex input). If a magnitude spectrogram is given, the phase will first be initialized using
            :func:`torch_specinv.methods.phase_init`; otherwise the iteration starts from a copy of the complex
            input and its magnitude is used as the target.
        maxiter (int): maximum number of iterations before timing out.
        tol (float): tolerance of the stopping condition based on L2 loss. Default: ``1e-6``
        rho (float): non-negative speedup parameter. A small value is preferable when the input spectrogram is
            noisy (imperfect); setting it to 1 will behave similarly to ``griffin_lim``. Default: ``0.1``
        verbose (bool): whether to be verbose (show the progress bar). Default: ``1``
        evaiter (int): step size for evaluation. Every ``evaiter`` iterations the function selected by ``metric``
            is evaluated and the stopping condition is checked. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Any other value falls back to spectral convergence. Default: ``'sc'``
        **stft_kwargs: other arguments that are passed to :func:`torch.stft`.

    Returns:
        A 1d tensor converted from the given spectrogram
    """
    n_fft, proccessed_args = _args_helper(spec, **stft_kwargs)
    istft = partial(_istft, n_fft=n_fft, **proccessed_args)

    if len(spec.shape) == 2:
        X = phase_init(spec, **stft_kwargs)
    else:
        # Complex (N x T x 2) input: start from a copy of it (X is updated in
        # place, the caller's tensor must not be) and take its magnitude as the
        # target. Stacking it with zeros would add a spurious fourth dimension
        # and the projection below would divide by a complex-shaped "magnitude".
        X = spec.clone()
        spec = X.pow(2).sum(2).sqrt()

    x = istft(X)
    Z = X.clone()
    Y = X.clone()
    U = torch.zeros_like(X)

    criterion = nn.MSELoss()
    init_loss = None
    previous_loss = None
    bar_dict = {}
    if metric == 'snr':
        metric_func = SNR
        bar_dict['SNR'] = 0
        metric = metric.upper()
    elif metric == 'ser':
        metric_func = SER
        bar_dict['SER'] = 0
        metric = metric.upper()
    else:
        metric_func = spectral_convergence
        bar_dict['spectral_convergence'] = 0
        metric = 'spectral_convergence'

    with tqdm(total=maxiter, disable=not verbose) as pbar:
        for i in range(maxiter):
            # Pc2: impose the target magnitude while keeping the current phase
            X[:] = Z - U
            mag = X.pow(2).sum(2).sqrt()
            # threshold_ clamps tiny magnitudes to 1e-7 to avoid division by zero
            X *= (spec / F.threshold_(mag, 1e-7, 1e-7)).unsqueeze(-1)
            Y[:] = X + U

            # Pc1: project onto consistent spectrograms (istft followed by stft)
            x[:] = istft(Y)
            reconstruted = torch.stft(x, n_fft, **stft_kwargs)

            Z[:] = (rho * Y + reconstruted) / (1 + rho)
            U += X - Z

            if i % evaiter == evaiter - 1:
                mag = reconstruted.pow(2).sum(2).sqrt()
                bar_dict[metric] = metric_func(mag, spec).item()
                l2_loss = criterion(mag, spec).item()
                pbar.set_postfix(**bar_dict, loss=l2_loss)
                pbar.update(evaiter)

                # `not init_loss` also covers a zero loss, which keeps the
                # relative-improvement division below well defined.
                if not init_loss:
                    init_loss = l2_loss
                elif (previous_loss - l2_loss) / init_loss < tol * evaiter and previous_loss > l2_loss:
                    break
                previous_loss = l2_loss

    return istft(X)
def RTISI_LA(spec, look_ahead=-1, asymmetric_window=False, maxiter=25, alpha=0.99, verbose=1, **stft_kwargs):
    r"""
    Reconstruct spectrogram phase using `Real-Time Iterative Spectrogram Inversion with Look Ahead`_ (RTISI-LA).

    .. _`Real-Time Iterative Spectrogram Inversion with Look Ahead`:
        https://lonce.org/home/Publications/publications/2007_RealtimeSignalReconstruction.pdf

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude).
        look_ahead (int): how many future frames will be considered. ``-1`` will set it to
            ``(win_length - 1) // hop_length``, ``0`` will disable the look-ahead strategy and fall back to the
            original RTISI algorithm. Default: ``-1``
        asymmetric_window (bool): whether to apply an asymmetric window on the first iteration for a newly
            arriving frame. Default: ``False``
        maxiter (int): number of iterations for each step. Default: ``25``
        alpha (float): speedup parameter used in `Fast Griffin-Lim`_, setting it to zero will disable it.
            Default: ``0.99``
        verbose (bool): whether to be verbose (show the progress bar). Default: ``1``
        **stft_kwargs: other arguments that are passed to :func:`torch.stft`.

    Returns:
        A 1d tensor converted from the given spectrogram
    """
    n_fft, proccessed_args = _args_helper(spec, **stft_kwargs)
    # Copy of the user's STFT arguments with centering forced off; used for the
    # per-step torch.stft call in the symmetric-window branch below.
    copyed_kwargs = stft_kwargs.copy()
    copyed_kwargs['center'] = False

    win_length = proccessed_args['win_length']
    hop_length = proccessed_args['hop_length']
    synth_coeff = proccessed_args['synth_coeff']
    offset = proccessed_args['offset']
    onesided = proccessed_args['onesided']
    normalized = proccessed_args['normalized']
    ola_weight = proccessed_args['ola_weight']
    window = proccessed_args['window']
    if window is None:
        # Fall back to a Hann window placed on the same device as the input.
        window = torch.hann_window(win_length).to(spec.device)

    # NOTE(review): presumably the number of earlier frames that still overlap
    # the current frame -- confirm against the paper.
    num_keep = (win_length - 1) // hop_length
    if look_ahead < 0:
        look_ahead = num_keep

    # First asymmetric window: sum of flipped windows shifted by 1..num_keep
    # hops, then rescaled using hop_length, its product with `window`, and
    # synth_coeff.
    asym_window1 = spec.new_zeros(win_length)
    for i in range(num_keep):
        asym_window1[(i + 1) * hop_length:] += window.flip(0)[:-(i + 1) * hop_length:]
    asym_window1 *= hop_length / (asym_window1 * window).sum() / synth_coeff

    # Second asymmetric window: same construction, but also includes the
    # unshifted (i == 0) flipped window.
    asym_window2 = spec.new_zeros(win_length)
    for i in range(num_keep + 1):
        asym_window2[i * hop_length:] += window.flip(0)[:-i * hop_length if i else None]
    asym_window2 *= hop_length / (asym_window2 * window).sum() / synth_coeff

    steps = spec.shape[1]
    # Buffer with one row of n_fft samples per frame; `xt_winview` is a slice of
    # it, so writes through the view are visible when `xt` is read below.
    xt = spec.new_zeros(steps + num_keep + 2 * look_ahead, n_fft)
    xt_winview = xt[:, offset:offset + win_length]
    # Pad the last dimension with `look_ahead` zero columns on both sides.
    spec = F.pad(spec, [look_ahead, look_ahead])

    def irfft(x):
        # Inverse real FFT using the same normalization/onesided settings as the STFT.
        return torch.irfft(x, 1, normalized=normalized, onesided=onesided, signal_sizes=[n_fft] if onesided else None)

    def rfft(x):
        # Forward real FFT using the same normalization/onesided settings as the STFT.
        return torch.rfft(x, 1, normalized=normalized, onesided=onesided)

    # initialize first frame with zero phase
    first_frame = spec[:, look_ahead]
    xt_winview[num_keep + look_ahead] = irfft(
        torch.stack((first_frame, torch.zeros_like(first_frame)), -1))[offset:offset + win_length]

    with tqdm(total=steps + look_ahead, disable=not verbose) as pbar:
        for i in range(steps + look_ahead):
            for j in range(maxiter):
                # Overlap-add the rows currently in scope for step i.
                x = _ola(xt_winview[i:i + num_keep + look_ahead + 1].t(), window, hop_length, synth_coeff, ola_weight)
                if asymmetric_window:
                    # Re-frame the overlap-added signal back into the buffer.
                    xt_winview[i + num_keep:i + num_keep + look_ahead + 1] = \
                        x.unfold(0, win_length, hop_length)[num_keep:]

                    # All frames but the newest get the regular window; the newest
                    # gets asym_window1 on the first iteration and asym_window2 after.
                    xt_winview[i + num_keep:i + num_keep + look_ahead] *= window
                    if j:
                        xt_winview[i + num_keep + look_ahead] *= asym_window2
                    else:
                        xt_winview[i + num_keep + look_ahead] *= asym_window1

                    new_spec = rfft(xt[i + num_keep:i + num_keep + look_ahead + 1]).transpose(0, 1)
                else:
                    new_spec = torch.stft(F.pad(
                        x[num_keep * hop_length - offset:], [0, offset]), n_fft=n_fft, **copyed_kwargs)

                # Momentum-style update scaled by alpha, using the previous estimate.
                if j:
                    new_spec += alpha * (new_spec - pre_spec)
                    pre_spec.copy_(new_spec)
                elif i:
                    # First iteration of a later step: the previous estimate is
                    # shifted by one frame relative to the current one.
                    new_spec[:, :-1] += alpha * (new_spec[:, :-1] - pre_spec[:, 1:])
                    pre_spec.copy_(new_spec)
                else:
                    # Very first iteration overall: nothing to compare against yet.
                    pre_spec = new_spec.clone()

                # Rescale to the target magnitudes; threshold_ clamps tiny
                # magnitudes to 1e-7 to avoid division by zero.
                mag = F.threshold_(new_spec.pow(2).sum(2).sqrt(), 1e-7, 1e-7)
                new_spec *= (spec[:, i:i + look_ahead + 1] / mag).unsqueeze(-1)

                xt_winview[i + num_keep:i + num_keep + look_ahead + 1] = irfft(
                    new_spec.transpose(0, 1))[:, offset:offset + win_length]

            pbar.update()

    # Final overlap-add, skipping the leading rows and the trailing look-ahead rows.
    x = _ola(xt_winview[num_keep + look_ahead:-look_ahead if look_ahead else None].t(), window, hop_length, synth_coeff,
             ola_weight)
    if proccessed_args['center']:
        # Trim half a window from each end when the STFT was centered.
        x = x[win_length // 2:-win_length // 2]
    else:
        x = F.pad(x, [offset, offset])

    return x
def griffin_lim(spec, maxiter: int = 200, tol: float = 1e-6, alpha: float = 0.99, verbose: bool = True,
                evaiter: int = 10, metric='sc', **stft_kwargs):
    r"""Reconstruct spectrogram phase using the well known `Griffin-Lim`_ algorithm and its variation, `Fast Griffin-Lim`_.

    .. _`Griffin-Lim`: https://pdfs.semanticscholar.org/14bc/876fae55faf5669beb01667a4f3bd324a4f1.pdf
    .. _`Fast Griffin-Lim`: https://perraudin.info/publications/perraudin-note-002.pdf

    Args:
        spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude) or :math:`(N \times T \times 2)`
            (complex input). If a magnitude spectrogram is given, the phase will first be initialized using
            :func:`torch_specinv.methods.phase_init`; otherwise the iteration starts from a copy of the complex
            input and its magnitude is used as the target.
        maxiter (int): maximum number of iterations before timing out. Default: ``200``
        tol (float): tolerance of the stopping condition based on L2 loss. Default: ``1e-6``
        alpha (float): speedup parameter used in `Fast Griffin-Lim`_, setting it to zero will disable it.
            Default: ``0.99``
        verbose (bool): whether to be verbose (show the progress bar). Default: :obj:`True`
        evaiter (int): step size for evaluation. Every ``evaiter`` iterations the function selected by `metric`
            is evaluated and the stopping condition is checked. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence),
            ``'snr'`` or ``'ser'``. Any other value falls back to spectral convergence. Default: ``'sc'``
        **stft_kwargs: other arguments that are passed to :func:`torch.stft`

    Returns:
        A 1d tensor converted from the given spectrogram
    """
    n_fft, proccessed_args = _args_helper(spec, **stft_kwargs)
    istft = partial(_istft, n_fft=n_fft, **proccessed_args)

    if len(spec.shape) == 2:
        new_spec = phase_init(spec, **stft_kwargs)
    else:
        # Complex (N x T x 2) input: start from a copy of it (new_spec is updated
        # in place, the caller's tensor must not be) and take its magnitude as the
        # target. Stacking it with zeros would add a spurious fourth dimension
        # and the projection below would divide by a complex-shaped "magnitude".
        new_spec = spec.clone()
        spec = new_spec.pow(2).sum(2).sqrt()

    pre_spec = new_spec.clone()
    x = istft(new_spec)

    criterion = nn.MSELoss()
    init_loss = None
    previous_loss = None
    bar_dict = {}
    if metric == 'snr':
        metric_func = SNR
        bar_dict['SNR'] = 0
        metric = metric.upper()
    elif metric == 'ser':
        metric_func = SER
        bar_dict['SER'] = 0
        metric = metric.upper()
    else:
        metric_func = spectral_convergence
        bar_dict['spectral_convergence'] = 0
        metric = 'spectral_convergence'

    with tqdm(total=maxiter, disable=not verbose) as pbar:
        for i in range(maxiter):
            # Re-analyse the current signal, then apply the alpha-scaled
            # acceleration step against the previous estimate.
            new_spec[:] = torch.stft(x, n_fft, **stft_kwargs)
            new_spec += alpha * (new_spec - pre_spec)
            pre_spec.copy_(new_spec)

            # Impose the target magnitude; threshold_ clamps tiny magnitudes to
            # 1e-7 (in place) to avoid division by zero.
            mag = new_spec.pow(2).sum(2).sqrt()
            new_spec *= (spec / F.threshold_(mag, 1e-7, 1e-7)).unsqueeze(-1)
            x[:] = istft(new_spec)

            if i % evaiter == evaiter - 1:
                bar_dict[metric] = metric_func(mag, spec).item()
                l2_loss = criterion(mag, spec).item()
                pbar.set_postfix(**bar_dict, loss=l2_loss)
                pbar.update(evaiter)

                # `not init_loss` also covers a zero loss, which keeps the
                # relative-improvement division below well defined.
                if not init_loss:
                    init_loss = l2_loss
                elif (previous_loss - l2_loss) / init_loss < tol * evaiter and previous_loss > l2_loss:
                    break
                previous_loss = l2_loss

    return x