예제 #1
0
def _plot_stfts(params):
    """Plot dB-scaled STFT magnitudes of one signal for several STFT settings.

    The signal returned by ``_get_signal()`` is truncated to 3 seconds and
    its STFT is computed once per entry in ``params``. Each spectrogram is
    drawn in its own row of a single figure, with all rows sharing the same
    axis limits so they can be compared directly.

    Parameters
    ----------
    params : sequence of tuple
        Each entry is unpacked positionally into ``nussl.STFTParams``.
        Entries ``[0]`` and ``[1]`` are reported in the subplot title as the
        window length and hop length.
    """
    sig = _get_signal()
    sig.truncate_seconds(3.0)

    # Build the figure in a single call so the requested size applies to the
    # figure that actually holds the subplots (a separate plt.figure() call
    # would leave an extra, empty figure behind). squeeze=False keeps `axs`
    # a 2-D array even when only one parameter set is given, so `axs.flat`
    # below always works.
    _, axs = plt.subplots(len(params), 1,
                          figsize=(7, 5 * len(params)),
                          squeeze=False)

    specs = []
    for prm in params:
        sig.stft_params = nussl.STFTParams(*prm)
        stft = sig.stft()
        # Magnitude in dB relative to the peak of this STFT.
        # NOTE(review): the (f, t) unpacking below assumes each squeezed
        # spectrogram is 2-D (e.g. a single-channel signal) - confirm.
        spec = np.squeeze(librosa.amplitude_to_db(np.abs(stft), ref=np.max))
        specs.append(spec)

    # Largest extent along each axis across all spectrograms; used as common
    # axis limits for every subplot.
    shapes = [s.shape for s in specs]
    f, t = np.max(shapes, axis=0)

    for ax, spec, prm in zip(axs.flat, specs, params):
        # Draw through the Axes object directly rather than switching the
        # pyplot "current axes" for every row.
        ax.imshow(spec, aspect='auto', cmap='magma')
        ax.set_xlim([0, t])
        # Setting ylim to [0, f] also puts row 0 at the bottom of the image.
        ax.set_ylim([0, f])
        ax.set_xlabel('Time Steps')
        ax.set_ylabel('Frequency bins')
        ax.set_title(f'Win Length = {prm[0]}, Hop Length = {prm[1]}')
        ax.label_outer()

    plt.show()
예제 #2
0
def signal(window_length: int = 2048,
           hop_length: int = 512,
           window_type: str = 'sqrt_hann',
           sample_rate: int = 44100):
    """
    Bundle the STFT configuration together with the sampling rate.

    Parameters
    ----------
    window_length : int, optional
        STFT window length, by default 2048.
    hop_length : int, optional
        STFT hop length, by default 512.
    window_type : str, optional
        Name of the STFT window, by default 'sqrt_hann'.
    sample_rate : int, optional
        Sampling rate, by default 44100.

    Returns
    -------
    tuple
        ``(stft_params, sample_rate)`` where ``stft_params`` is a
        ``nussl.STFTParams`` built from the first three arguments and
        ``sample_rate`` is passed through unchanged.
    """
    stft_params = nussl.STFTParams(window_length, hop_length, window_type)
    return stft_params, sample_rate
예제 #3
0
def test_stft_module(combo, one_item):
    """Check the STFT network module against itself and against nussl.

    Two properties are asserted:

    1. Transforming ``one_item['mix_audio']`` and inverting the result
       reproduces the input to within 1e-6.
    2. The first half of the encoded representation matches the magnitude
       of the STFT computed by ``nussl.AudioSignal`` with the same
       parameters, to within 1e-6.

    Parameters
    ----------
    combo : sequence
        ``combo[0]`` is the window length, ``combo[1]`` scales the window
        length to give the hop length, and ``combo[2]`` is the window type.
    one_item : dict
        Must contain a ``'mix_audio'`` entry holding the audio to encode.
    """
    win_length = combo[0]
    # Hop length is expressed as a fraction of the window length.
    hop_length = int(combo[0] * combo[1])
    win_type = combo[2]
    window = nussl.AudioSignal.get_window(combo[2], win_length)
    stft_params = nussl.STFTParams(window_length=win_length,
                                   hop_length=hop_length,
                                   window_type=win_type)

    representation = ml.networks.modules.STFT(win_length,
                                              hop_length=hop_length,
                                              window_type=win_type)

    # NOTE(review): `assert True` never fails, so this branch is a no-op and
    # the checks below still run for windows that fail the COLA check. If
    # such combinations are meant to be skipped, this likely should return
    # or skip instead - confirm the intent.
    if not check_COLA(window, win_length, win_length - hop_length):
        assert True

    data = one_item['mix_audio']

    # Round trip: forward transform followed by the inverse transform.
    encoded = representation(data, 'transform')
    decoded = representation(encoded, 'inverse')
    # Drop the leading dimension and swap the next two axes before comparing
    # with nussl's STFT.
    # NOTE(review): presumably this reorders to nussl's axis layout - confirm.
    encoded = encoded.squeeze(0).permute(1, 0, 2)

    # Reconstruction must match the original audio.
    assert (decoded - data).abs().max() < 1e-6

    # Reference STFT computed by nussl; the sample rate is fixed at 16000.
    audio_signal = nussl.AudioSignal(audio_data_array=data.squeeze(0).numpy(),
                                     sample_rate=16000,
                                     stft_params=stft_params)
    nussl_magnitude = np.abs(audio_signal.stft())
    _encoded = encoded.squeeze(0)
    # Only the first half along the first axis is compared with the magnitude.
    # NOTE(review): presumably the second half holds the phase - confirm.
    cutoff = _encoded.shape[0] // 2
    _encoded = _encoded[:cutoff, ...]
    assert (_encoded - nussl_magnitude).abs().max() < 1e-6
예제 #4
0
def one_item(scaper_folder):
    """Yield one randomly chosen, batch-shaped item from a Scaper dataset.

    A ``nussl.datasets.Scaper`` dataset is built over ``scaper_folder`` with
    a 512-sample window and 128-sample hop, and a transform pipeline of
    ``PhaseSensitiveSpectrumApproximation``, ``GetAudio`` and
    ``ToSeparationModel``. A random index is drawn with
    ``np.random.randint`` and every tensor in the selected item gets a
    leading dimension of size 1 so it looks like a batch of one.

    Parameters
    ----------
    scaper_folder
        Location handed to ``nussl.datasets.Scaper``.

    Yields
    ------
    dict
        The selected item, with tensors unsqueezed along dimension 0.
    """
    params = nussl.STFTParams(window_length=512, hop_length=128)
    pipeline = transforms.Compose([
        transforms.PhaseSensitiveSpectrumApproximation(),
        transforms.GetAudio(),
        transforms.ToSeparationModel(),
    ])
    dataset = nussl.datasets.Scaper(
        scaper_folder,
        transform=pipeline,
        stft_params=params,
    )

    index = np.random.randint(len(dataset))
    sample = dataset[index]

    for key in sample:
        value = sample[key]
        if not torch.is_tensor(value):
            continue
        # Add a fake batch dimension in front.
        sample[key] = value.unsqueeze(0)

    yield sample
예제 #5
0
def simple_sine_data():
    """Build one sine-wave example in both raw and model-ready form.

    The nussl seed is set to 0, then a two-source ``SineWaves`` dataset is
    created at 8000 Hz with a 256-sample window and 64-sample hop, using the
    ``PhaseSensitiveSpectrumApproximation`` and ``MagnitudeWeights``
    transforms. The first item is taken as-is; a deep copy of it is passed
    through ``ToSeparationModel`` and every tensor in that copy is moved to
    ``DEVICE``, cast to float, given a leading dimension of size 1 and made
    contiguous.

    Returns
    -------
    tuple
        ``(item, tensorized)``: the untouched dataset item and its
        tensorized, batch-shaped counterpart.
    """
    nussl.utils.seed(0)
    # No data is read from disk, so the folder name is a placeholder.
    placeholder_folder = 'ignored'

    params = nussl.STFTParams(window_length=256, hop_length=64)
    pipeline = nussl.datasets.transforms.Compose([
        nussl.datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        nussl.datasets.transforms.MagnitudeWeights(),
    ])
    to_model_input = nussl.datasets.transforms.ToSeparationModel()

    dataset = SineWaves(
        placeholder_folder,
        sample_rate=8000,
        stft_params=params,
        transform=pipeline,
        num_sources=2,
    )
    item = dataset[0]

    # Tensorize a copy so the original item is returned unmodified.
    tensorized = to_model_input(copy.deepcopy(item))

    for name in tensorized:
        value = tensorized[name]
        if not torch.is_tensor(value):
            continue
        batched = value.to(DEVICE).float().unsqueeze(0)
        tensorized[name] = batched.contiguous()

    return item, tensorized
예제 #6
0
파일: training.py 프로젝트: speechdnn/nussl
        return output


# -

# As a reminder, this dataset makes random mixtures of sine waves with fundamental frequencies
# between 110 Hz and 4000 Hz. Let's now set it up with appropriate STFT parameters that result
# in 129 frequencies in the spectrogram.

# +
# Seed nussl so the generated example is the same on every run.
nussl.utils.seed(0)  # make sure this does the same thing each time

# We're not reading data, so we can 'ignore' the folder
folder = 'ignored'

# STFT settings shared by every item in the dataset: 256-sample windows with
# a 64-sample hop.
stft_params = nussl.STFTParams(window_length=256, hop_length=64)

# Sine-wave dataset at an 8000 Hz sample rate using the STFT settings above.
sine_wave_dataset = SineWaves(folder,
                              sample_rate=8000,
                              stft_params=stft_params)

# Take the first item of the dataset as the working example.
item = sine_wave_dataset[0]


def visualize_and_embed(sources, y_axis='mel'):
    """Draw the given sources as masks in a single 10x4 figure.

    Parameters
    ----------
    sources
        Passed straight to ``nussl.utils.visualize_sources_as_masks``.
    y_axis : str, optional
        Frequency-axis scaling forwarded to the plotting helper, by
        default 'mel'.
    """
    # NOTE(review): despite the name, nothing is embedded in this body -
    # confirm whether part of the function is missing.
    plot_options = {'db_cutoff': -60, 'y_axis': y_axis}

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 1, 1)
    nussl.utils.visualize_sources_as_masks(sources, **plot_options)
    plt.tight_layout()
예제 #7
0
# 3. `window_type`: What sort of windowing function to use?
#
# These three parameters are grouped into a `namedtuple` object that
# belongs to every AudioSignal object. This is the `STFTParams` object:

nussl.STFTParams

# When we created the audio signals above, the STFT parameters were built on initialization:

signal1.stft_params

# The STFT parameters are built using helpful defaults based on properties of the audio signal.
# 32 millisecond windows are used with an 8 millisecond hop between windows. At 44100 Hz,
# this results in 2048 for the window length and 512 for the hop length. The window
# type is the `sqrt_hann` window, which generally has better separation performance.
# There are many windows that can be used:

nussl.constants.ALL_WINDOWS

# An AudioSignal's STFT parameters can be set after the fact. Let's change the one for signal1:

signal1.stft_params = nussl.STFTParams(window_length=256, hop_length=128)
signal1.stft().shape

# The shape of the resultant STFT is now different. Note that 256 resulted in 129
# frequencies of analysis per frame. In general, the rule is `(window_length // 2) + 1`.

# Report the wall-clock time elapsed since `start_time`, which is set
# earlier, outside this excerpt.
end_time = time.time()
time_taken = end_time - start_time
print(f'Time taken: {time_taken:.4f} seconds')