def _plot_stfts(params):
    """Plot one dB-scaled magnitude spectrogram per STFT setting, stacked vertically.

    The signal returned by ``_get_signal()`` is truncated to 3 seconds, its
    STFT is recomputed for every entry of ``params``, and each result is
    drawn in its own row of a single figure. All rows share the same axis
    limits (the largest spectrogram shape) so differences in resolution are
    directly comparable.

    Parameters
    ----------
    params : sequence
        One entry per subplot. Each entry is unpacked positionally into
        ``nussl.STFTParams``; elements ``[0]`` and ``[1]`` are reported in
        the subplot title as window length and hop length.
    """
    # plt.style.use('seaborn')
    sig = _get_signal()
    sig.truncate_seconds(3.0)

    # Apply the size to the figure that actually holds the subplots (a
    # separate plt.figure() call would only open an extra, empty figure).
    # squeeze=False keeps `axs` a 2-D array even for a single entry, so
    # `axs.flat` below works for any non-zero number of parameter sets.
    fig, axs = plt.subplots(
        len(params), 1, figsize=(7, 5 * len(params)), squeeze=False)

    specs = []
    for prm in params:
        sig.stft_params = nussl.STFTParams(*prm)
        stft = sig.stft()
        spec = np.squeeze(librosa.amplitude_to_db(np.abs(stft), ref=np.max))
        specs.append(spec)

    # Largest extent along each spectrogram axis, used as common limits.
    shapes = [s.shape for s in specs]
    f, t = np.max(shapes, axis=0)

    for ax, spec, prm in zip(axs.flat, specs, params):
        # Draw through the Axes object itself instead of switching pyplot's
        # "current axes" state for every row.
        ax.imshow(spec, aspect='auto', cmap='magma')
        ax.set_xlim([0, t])
        ax.set_ylim([0, f])
        ax.set_xlabel('Time Steps')
        ax.set_ylabel('Frequency bins')
        ax.set_title(f'Win Length = {prm[0]}, Hop Length = {prm[1]}')
        ax.label_outer()
        # plt.gca().invert_yaxis()
    plt.show()
def signal(window_length: int = 2048,
           hop_length: int = 512,
           window_type: str = 'sqrt_hann',
           sample_rate: int = 44100):
    """
    Bundle the global AudioSignal settings: STFT parameters and sample rate.

    Parameters
    ----------
    window_length : int, optional
        STFT window length, by default 2048.
    hop_length : int, optional
        STFT hop length, by default 512.
    window_type : str, optional
        STFT window type, by default 'sqrt_hann'.
    sample_rate : int, optional
        Sampling rate, by default 44100.

    Returns
    -------
    tuple
        ``(stft_params, sample_rate)``, where ``stft_params`` is the
        ``nussl.STFTParams`` built from the three STFT arguments and
        ``sample_rate`` is passed through unchanged.
    """
    stft_params = nussl.STFTParams(
        window_length=window_length,
        hop_length=hop_length,
        window_type=window_type,
    )
    return stft_params, sample_rate
def test_stft_module(combo, one_item):
    """Round-trip the STFT network module and compare it with nussl's STFT.

    ``combo`` is read by index: ``combo[0]`` is the window length,
    ``combo[1]`` scales it to obtain the hop length, and ``combo[2]`` is the
    window type. ``one_item`` provides the audio under ``'mix_audio'``.
    """
    win_length, win_type = combo[0], combo[2]
    hop_length = int(combo[0] * combo[1])

    window = nussl.AudioSignal.get_window(win_type, win_length)
    stft_params = nussl.STFTParams(
        window_length=win_length,
        hop_length=hop_length,
        window_type=win_type,
    )
    representation = ml.networks.modules.STFT(
        win_length, hop_length=hop_length, window_type=win_type)

    # NOTE(review): this branch has no effect -- `assert True` never fails and
    # the test continues either way. If combinations that violate COLA were
    # meant to be skipped, this needs an explicit skip/return; confirm intent.
    if not check_COLA(window, win_length, win_length - hop_length):
        assert True

    mixture = one_item['mix_audio']

    # Forward transform followed by its inverse must reproduce the input.
    transformed = representation(mixture, 'transform')
    reconstructed = representation(transformed, 'inverse')
    transformed = transformed.squeeze(0).permute(1, 0, 2)

    reconstruction_error = (reconstructed - mixture).abs().max()
    assert reconstruction_error < 1e-6

    # Reference magnitude computed by nussl with the same STFT parameters.
    reference_signal = nussl.AudioSignal(
        audio_data_array=mixture.squeeze(0).numpy(),
        sample_rate=16000,
        stft_params=stft_params,
    )
    nussl_magnitude = np.abs(reference_signal.stft())

    # Only the first half of the module output's leading axis is compared
    # against the nussl magnitude.
    module_output = transformed.squeeze(0)
    half = module_output.shape[0] // 2
    first_half = module_output[:half, ...]
    assert (first_half - nussl_magnitude).abs().max() < 1e-6
def one_item(scaper_folder):
    """Yield one randomly chosen Scaper item whose tensors carry a batch axis.

    The dataset is built from ``scaper_folder`` with window length 512 and
    hop length 128, and every item passes through
    ``PhaseSensitiveSpectrumApproximation``, ``GetAudio`` and
    ``ToSeparationModel``. The index is drawn with ``np.random.randint``.
    """
    pipeline = transforms.Compose([
        transforms.PhaseSensitiveSpectrumApproximation(),
        transforms.GetAudio(),
        transforms.ToSeparationModel(),
    ])
    dataset = nussl.datasets.Scaper(
        scaper_folder,
        transform=pipeline,
        stft_params=nussl.STFTParams(window_length=512, hop_length=128),
    )
    data = dataset[np.random.randint(len(dataset))]

    # fake a batch dimension on every tensor entry; other entries are untouched
    tensor_keys = [key for key in data if torch.is_tensor(data[key])]
    for key in tensor_keys:
        data[key] = data[key].unsqueeze(0)

    yield data
def simple_sine_data():
    """Build one seeded SineWaves example in raw and model-ready form.

    Returns
    -------
    tuple
        ``(item, tensorized)``. ``item`` is element 0 of a two-source
        ``SineWaves`` dataset (sample rate 8000, window length 256, hop
        length 64) after the transform pipeline. ``tensorized`` is a deep
        copy of it passed through ``ToSeparationModel``, with every tensor
        moved to ``DEVICE``, cast to float, given a leading batch axis and
        made contiguous.
    """
    nussl.utils.seed(0)

    tfm_lib = nussl.datasets.transforms
    pipeline = tfm_lib.Compose([
        tfm_lib.PhaseSensitiveSpectrumApproximation(),
        tfm_lib.MagnitudeWeights(),
    ])
    to_model_input = tfm_lib.ToSeparationModel()

    # No data is read from disk, so the folder argument is a placeholder.
    dataset = SineWaves(
        'ignored',
        sample_rate=8000,
        stft_params=nussl.STFTParams(window_length=256, hop_length=64),
        transform=pipeline,
        num_sources=2,
    )
    item = dataset[0]

    # Work on a copy so the untouched item can be returned alongside.
    tensorized = to_model_input(copy.deepcopy(item))
    for key in tensorized:
        value = tensorized[key]
        if not torch.is_tensor(value):
            continue
        batched = value.to(DEVICE).float().unsqueeze(0)
        tensorized[key] = batched.contiguous()

    return item, tensorized
return output  # tail of a definition that begins before this excerpt; left as-is

# -

# As a reminder, this dataset makes random mixtures of sine waves with fundamental frequencies
# between 110 Hz and 4000 Hz. Let's now set it up with appropriate STFT parameters that result
# in 129 frequencies in the spectrogram.

# +
nussl.utils.seed(0)  # make sure this does the same thing each time

# We're not reading data, so we can 'ignore' the folder
folder = 'ignored'
stft_params = nussl.STFTParams(window_length=256, hop_length=64)

sine_wave_dataset = SineWaves(folder, sample_rate=8000, stft_params=stft_params)
item = sine_wave_dataset[0]


def visualize_and_embed(sources, y_axis='mel'):
    """Plot ``sources`` as masks on a single 10x4 figure.

    Parameters
    ----------
    sources
        Passed unchanged to ``nussl.utils.visualize_sources_as_masks``.
    y_axis : str, optional
        Forwarded as the ``y_axis`` argument of the same call, by default
        'mel'.

    NOTE(review): the name suggests an additional "embed" step, but none is
    visible in this excerpt -- the definition may continue beyond it.
    """
    plt.figure(figsize=(10, 4))
    plt.subplot(111)
    # Masks are drawn with a -60 dB cutoff.
    nussl.utils.visualize_sources_as_masks(sources, db_cutoff=-60, y_axis=y_axis)
    plt.tight_layout()
# 3. `window_type`: What sort of windowing function to use? # # These three parameters are grouped into a `namedtuple` object that # belongs to every AudioSignal object. This is the `STFTParams` object: nussl.STFTParams # When we created the audio signals above, the STFT parameters were built on initialization: signal1.stft_params # The STFT parameters are built using helpful defaults based on properties of the audo signal. # 32 millisecond windows are used with an 8 millisecond hop between windows. At 44100 Hz, # this results in 2048 for the window length and 512 for the hop length. The window # type is the `sqrt_hann` window, which generally has better separation performance. # There are many windows that can be used: nussl.constants.ALL_WINDOWS # An AudioSignal's STFT parameters can be set after the fact. Let's change the one for signal1: signal1.stft_params = nussl.STFTParams(window_length=256, hop_length=128) signal1.stft().shape # The shape of the resultant STFT is now different. Note that 256 resulted in 129 # frequencies of analysis per frame. In general, the rule is `(window_length // 2) + 1`. end_time = time.time() time_taken = end_time - start_time print(f'Time taken: {time_taken:.4f} seconds')