コード例 #1
0
def get_chunks(path, hparams):
    """Segment a wav file into syllables and compute a spectrogram per chunk.

    Loads the audio at ``path``, runs dynamic-threshold segmentation, then
    processes each detected syllable through the same preprocessing pipeline
    used at training time.

    Returns a tuple ``(waveform, chunks_mS, start_samples, end_samples)`` or
    ``None`` when segmentation fails.
    """
    # mel basis
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    # load file
    x_s, _ = prepare_wav(wav_loc=path, hparams=hparams, debug=False)
    # Segmentation params
    min_level_db_floor = -30
    db_delta = 5
    silence_threshold = 0.01
    min_silence_for_spec = 0.05
    # FIX: a trailing comma previously made this a (1.0,) tuple instead of a float
    max_vocal_for_spec = 1.0
    min_syllable_length_s = 0.05
    # segment
    results = dynamic_threshold_segmentation(
        x_s,
        hparams.sr,
        n_fft=hparams.n_fft,
        hop_length=hparams.hop_length_samples,
        win_length=hparams.win_length_samples,
        min_level_db_floor=min_level_db_floor,
        db_delta=db_delta,
        ref_level_db=hparams.ref_level_db,
        pre=hparams.preemphasis,
        min_silence_for_spec=min_silence_for_spec,
        max_vocal_for_spec=max_vocal_for_spec,
        silence_threshold=silence_threshold,
        verbose=True,
        min_syllable_length_s=min_syllable_length_s,
        spectral_range=[
            hparams.mel_lower_edge_hertz, hparams.mel_upper_edge_hertz
        ],
    )
    if results is None:
        print('Cannot segment the input file')
        return None
    # chunks: convert onset/offset times (seconds) to sample indices
    start_times = results["onsets"]
    end_times = results["offsets"]
    chunks_mS = []
    start_samples = []
    end_samples = []
    for start_time, end_time in zip(start_times, end_times):
        start_sample = int(start_time * hparams.sr)
        end_sample = int(end_time * hparams.sr)
        syl = x_s[start_sample:end_sample]

        # To avoid mistakes, reproduce the whole preprocessing pipeline, even (here useless) int casting
        _, mS, _ = process_syllable(syl,
                                    hparams,
                                    mel_basis=mel_basis,
                                    debug=False)
        if mS is None:
            # syllable was rejected by the pipeline (e.g. too short); skip it
            continue
        mS_int = (mS * 255).astype('uint8')
        sample = SpectroDataset.process_mSp(mS_int)

        chunks_mS.append(sample)
        start_samples.append(start_sample)
        end_samples.append(end_sample)
    return x_s, chunks_mS, start_samples, end_samples
コード例 #2
0
def plot_generation(model, hparams, num_examples, savepath):
    """Sample spectrograms from the model, save them as a figure and,
    when ``hparams`` is provided, invert each one to audio on disk.

    Returns a dict with the generated audios and spectrograms.
    """
    model.eval()
    generated = model.generate(batch_dim=num_examples).cpu().detach().numpy()

    # Figure: one column per generated example
    image_dims = generated.shape[2:]
    plt.clf()
    fig, axes = plt.subplots(ncols=num_examples)
    for col in range(num_examples):
        axes[col].matshow(generated[col].reshape(image_dims), origin="lower")
    for axis in fig.get_axes():
        axis.set_xticks([])
        axis.set_yticks([])
    plt.savefig(f'{savepath}/spectro.pdf')
    plt.close('all')

    # Audio inversion (skipped when hparams is None)
    audios = []
    if hparams is not None:
        mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
        mel_inversion_basis = build_mel_inversion_basis(mel_basis)
        for example_index in range(num_examples):
            waveform = inv_spectrogram_sp(
                generated[example_index, 0],
                n_fft=hparams.n_fft,
                win_length=hparams.win_length_samples,
                hop_length=hparams.hop_length_samples,
                ref_level_db=hparams.ref_level_db,
                power=hparams.power,
                mel_inversion_basis=mel_inversion_basis)
            audios.append(waveform)
            sf.write(f'{savepath}/{example_index}.wav',
                     waveform,
                     samplerate=hparams.sr)
    return {'audios': audios, 'spectros': generated}
コード例 #3
0
def calibrate_db_file(sr, num_mel_bins, n_fft, mel_lower_edge_hertz,
                      mel_upper_edge_hertz, hop_length_ms, win_length_ms,
                      power, ref_level_db, wav_loc):
    """Compute the min/max dB of a wav file's spectrogram for dB calibration.

    Builds an ``HParams`` from the arguments, (re)creates a dump folder keyed
    by the hparams repr, reads the wave, and returns
    ``{'max_db': ..., 'min_db': ...}`` (or ``None`` when the spectrogram
    debug info is unavailable, e.g. the chunk is too short).
    """
    hparams = HParams(
        sr=sr,
        num_mel_bins=num_mel_bins,
        n_fft=n_fft,
        mel_lower_edge_hertz=mel_lower_edge_hertz,
        mel_upper_edge_hertz=mel_upper_edge_hertz,
        power=power,  # for spectral inversion
        butter_lowcut=mel_lower_edge_hertz,
        butter_highcut=mel_upper_edge_hertz,
        ref_level_db=ref_level_db,  # 20
        mask_spec=False,
        win_length_ms=win_length_ms,
        hop_length_ms=hop_length_ms,
        mask_spec_kwargs={
            "spec_thresh": 0.9,
            "offset": 1e-10
        },
        n_jobs=1,
        verbosity=1,
        reduce_noise=True,
        noise_reduce_kwargs={
            "n_std_thresh": 2.0,
            "prop_decrease": 0.8
        },
        griffin_lim_iters=50)
    suffix = hparams.__repr__()

    # FIX: `'dump' / f'{suffix}'` raised TypeError (str / str); build the
    # path with an f-string as the other scripts do.
    dump_folder = f'dump/{suffix}'
    if os.path.isdir(dump_folder):
        shutil.rmtree(dump_folder)
    os.makedirs(dump_folder)

    #  Read wave
    data, _ = prepare_wav(wav_loc, hparams, debug=True)

    # create spec (mel projection only when a mel bin count is given)
    if num_mel_bins is not None:
        mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    else:
        mel_basis = None
    # melspec, debug_info = spectrogram_librosa(data, hparams, _mel_basis=mel_basis, debug=True)
    _, debug_info = spectrogram_sp(y=data,
                                   sr=hparams.sr,
                                   n_fft=hparams.n_fft,
                                   win_length_ms=hparams.win_length_ms,
                                   hop_length_ms=hparams.hop_length_ms,
                                   ref_level_db=hparams.ref_level_db,
                                   _mel_basis=mel_basis,
                                   pre_emphasis=hparams.preemphasis,
                                   power=hparams.power,
                                   debug=True)

    return {
        'max_db': debug_info['max_db'],
        'min_db': debug_info['min_db']
    } if debug_info is not None else None
コード例 #4
0
def plot_interpolations(model, hparams, dataloader, savepath,
                        num_interpolated_points, method, custom_data):
    """Interpolate between latent codes and plot/render the decoded results.

    Start/end latents come either from consecutive examples of one dataloader
    batch, or from ``custom_data['start_data']``/``custom_data['end_data']``.
    ``method`` is ``'linear'`` or ``'constant_radius'``.
    """
    model.eval()
    if custom_data is None:
        # One batch; interpolate between consecutive examples of the batch.
        for _, data in enumerate(dataloader):
            x_cuda = cuda_variable(data['input'])
            mu, logvar = model.encode(x_cuda)
            latent = model.reparameterize(mu, logvar)
            start_z = latent[:-1]
            end_z = latent[1:]
            batch_dim, rgb_dim, h_dim, w_dim = x_cuda.shape
            num_examples = batch_dim - 1
            break
    else:
        # Encode user-provided start and end examples separately.
        x_cuda = cuda_variable(torch.tensor(custom_data['start_data']))
        mu, logvar = model.encode(x_cuda)
        start_z = model.reparameterize(mu, logvar)
        x_cuda = cuda_variable(torch.tensor(custom_data['end_data']))
        mu, logvar = model.encode(x_cuda)
        end_z = model.reparameterize(mu, logvar)
        batch_dim, rgb_dim, h_dim, w_dim = x_cuda.shape
        num_examples = batch_dim

    x_interpolation = np.zeros(
        (num_examples, rgb_dim, h_dim, w_dim, num_interpolated_points))

    # Decode one latent per interpolation step t in [0, 1].
    for step, t in enumerate(
            np.linspace(start=0, stop=1, num=num_interpolated_points)):
        if method == 'linear':
            this_z = start_z * (1 - t) + end_z * t
        elif method == 'constant_radius':
            this_z = constant_radius_interpolation(start_z, end_z, t)
        else:
            raise NotImplementedError
        x_interpolation[:, :, :, :, step] = \
            model.decode(this_z).cpu().detach().numpy()

    # Plot grid: rows are examples, columns are interpolation steps.
    dims = h_dim, w_dim
    plt.clf()
    fig, axes = plt.subplots(nrows=num_examples, ncols=num_interpolated_points)
    for row in range(num_examples):
        for col in range(num_interpolated_points):
            axes[row, col].matshow(
                x_interpolation[row, :, :, :, col].reshape(dims),
                origin="lower")
    for axis in fig.get_axes():
        axis.set_xticks([])
        axis.set_yticks([])
    plt.savefig(f'{savepath}/spectro.pdf')
    plt.close('all')

    # Audio inversion (allocated lazily once the audio length is known).
    audios = None
    if hparams is not None:
        mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
        mel_inversion_basis = build_mel_inversion_basis(mel_basis)
        for row in range(num_examples):
            for col in range(num_interpolated_points):
                audio = inv_spectrogram_sp(
                    x_interpolation[row, 0, :, :, col],
                    n_fft=hparams.n_fft,
                    win_length=hparams.win_length_samples,
                    hop_length=hparams.hop_length_samples,
                    ref_level_db=hparams.ref_level_db,
                    power=hparams.power,
                    mel_inversion_basis=mel_inversion_basis)
                if audios is None:
                    audios = np.zeros(
                        (num_examples, num_interpolated_points, len(audio)))
                audios[row, col] = audio
                sf.write(f'{savepath}/{row}_{col}.wav',
                         audio,
                         samplerate=hparams.sr)
    return {'audios': audios, 'spectros': x_interpolation, 'dims': dims}
コード例 #5
0
def main(config,
         load,
         train):
    """Train (optionally) and evaluate a spectrogram model.

    Args:
        config: path/dict used by ``get_model_and_dataset`` to build everything.
        load: loading epoch; falsy means a fresh run (directories are created
            and the config file is copied next to the model).
        train: when truthy, run the training loop before the generation phase.
    """
    # Init model and dataset
    model, dataset_train, dataset_val, optimizer, hparams, config, model_path, config_path = get_model_and_dataset(
        config=config, loading_epoch=load)

    # Training
    best_val_loss = float('inf')
    num_examples_plot = 10
    if train:
        print('##### Train')
        # Copy config file in the save directory before training
        if not load:
            if not os.path.exists(model_path):
                os.makedirs(model_path)
                os.mkdir(f'{model.model_dir}/training_plots/')
                os.mkdir(f'{model.model_dir}/plots/')
            shutil.copy(config_path, f'{model_path}/config.py')

        writer = SummaryWriter(f'{model.model_dir}')

        # Epochs
        for ind_epoch in range(config['num_epochs']):
            # Dataloaders are rebuilt every epoch (and deleted below) —
            # presumably to reshuffle / release workers; verify before changing.
            train_dataloader = get_dataloader(dataset_type=config['dataset'], dataset=dataset_train,
                                              batch_size=config['batch_size'], shuffle=True)
            val_dataloader = get_dataloader(dataset_type=config['dataset'], dataset=dataset_val,
                                            batch_size=config['batch_size'], shuffle=True)

            train_loss = epoch(model, optimizer, train_dataloader,
                               num_batches=config['num_batches'], training=True)
            val_loss = epoch(model, optimizer, val_dataloader,
                             num_batches=config['num_batches'], training=False)
            writer.add_scalar('train_loss', train_loss, ind_epoch)
            writer.add_scalar('val_loss', val_loss, ind_epoch)

            print(f'Epoch {ind_epoch}:')
            print(f'Train loss {train_loss}:')
            print(f'Val loss {val_loss}:')

            del train_dataloader, val_dataloader

            # if (val_loss < best_val_loss) and (ind_epoch % 200 == 0):
            # NOTE: checkpointing is purely periodic; best_val_loss is
            # currently unused (see commented-out condition above).
            if ind_epoch % 10 == 0:
                # Save model
                model.save(name=ind_epoch)

                # Plots
                if not os.path.isdir(f'{model.model_dir}/training_plots/{ind_epoch}'):
                    os.mkdir(f'{model.model_dir}/training_plots/{ind_epoch}')
                test_dataloader = get_dataloader(dataset_type=config['dataset'], dataset=dataset_val,
                                                 batch_size=num_examples_plot, shuffle=True)
                savepath = f'{model.model_dir}/training_plots/{ind_epoch}/reconstructions'
                os.mkdir(savepath)
                plot_reconstruction(model, hparams, test_dataloader, savepath, custom_data=None)
                savepath = f'{model.model_dir}/training_plots/{ind_epoch}/generations'
                os.mkdir(savepath)
                plot_generation(model, hparams, num_examples_plot, savepath)
                del test_dataloader

    # Generations
    print('##### Generate')
    USE_CUSTOM_SAMPLE = False
    # Use own samples for generation
    if USE_CUSTOM_SAMPLE:
        # Hard-coded wav pairs: each entry interpolates 'start' -> 'end'.
        data_source = {
            0: {
                'start': {'path': 'data/raw/source_generation/0.wav'},
                'end': {'path': 'data/raw/source_generation/1.wav'},
            },
            1: {
                'start': {'path': 'data/raw/source_generation/0.wav'},
                'end': {'path': 'data/raw/source_generation/2.wav'},
            },
            2: {
                'start': {'path': 'data/raw/source_generation/1.wav'},
                'end': {'path': 'data/raw/source_generation/2.wav'},
            }
        }
        start_data = []
        end_data = []
        for example_ind, examples_dict in data_source.items():
            for name in ['start', 'end']:
                # read file
                syl, _ = prepare_wav(
                    wav_loc=examples_dict[name]['path'], hparams=hparams, debug=False)
                mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
                # process syllable
                sn, mSp, _ = process_syllable(
                    syl=syl, hparams=hparams, mel_basis=mel_basis, debug=False)
                if name == 'start':
                    start_data.append(SpectroDataset.process_mSp(mSp))
                elif name == 'end':
                    end_data.append(SpectroDataset.process_mSp(mSp))
        all_data = start_data + end_data
        #  Batchify
        start_data = np.stack(start_data)
        end_data = np.stack(end_data)
        all_data = np.stack(all_data)
        custom_data = {
            'start_data': start_data,
            'end_data': end_data,
            'all_data': all_data,
        }
    else:
        custom_data = None

    # Dataloader from the dataset, used for latent space visualisation and to feed data for reconstruction
    # and interpolation if not custom examples are provided
    gen_dataloader = get_dataloader(dataset_type=config['dataset'], dataset=dataset_val,
                                    batch_size=num_examples_plot, shuffle=True)

    if not os.path.isdir(f'{model.model_dir}/plots'):
        os.mkdir(f'{model.model_dir}/plots')

    # Reconstructions
    savepath = f'{model.model_dir}/plots/reconstructions'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_reconstruction(model, hparams, gen_dataloader, savepath, custom_data=custom_data)

    # Sampling
    savepath = f'{model.model_dir}/plots/generations'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_generation(model, hparams, num_examples_plot, savepath)

    # Linear interpolations
    savepath = f'{model.model_dir}/plots/linear_interpolations'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_interpolations(model, hparams, gen_dataloader, savepath, num_interpolated_points=10, method='linear',
                        custom_data=custom_data)

    # Constant r interpolations
    savepath = f'{model.model_dir}/plots/constant_r_interpolations'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_interpolations(model, hparams, gen_dataloader, savepath,
                        num_interpolated_points=10, method='constant_radius', custom_data=custom_data)

    # TODO
    # Translations

    # Check geometric organistation of the latent space per species
    savepath = f'{model.model_dir}/plots/stats'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    # latent_space_stats_per_species(model, gen_dataloader, savepath)
    plot_tsne_latent(model, gen_dataloader, savepath)
コード例 #6
0
def _load_chunks(path, hparams):
    """Segment the wav at *path* and package the results in a dict.

    Returns a dict with keys 'path', 'waveform', 'chunks', 'start_samples'
    and 'end_samples' (mirrors the tuple returned by ``get_chunks``).
    """
    waveform, chunks, start_samples, end_samples = get_chunks(
        path=path, hparams=hparams)
    return {
        'path': path,
        'waveform': waveform,
        'chunks': chunks,
        'start_samples': start_samples,
        'end_samples': end_samples,
    }


def main(config_path, loading_epoch, source_path, contamination_path,
         contamination_parameters, method):
    """Contaminate syllables of a source file with syllables from another file.

    Randomly selects syllables of ``source_path`` (with probability
    ``contamination_parameters['p_contamination']``), interpolates their
    latent codes towards randomly chosen syllables of ``contamination_path``
    (``method``: 'linear' or 'constant_radius'), and writes the patched
    waveform to ``<model_dir>/plots/contaminations/contamination.wav``.
    """
    # load model
    model, _, _, _, hparams, _, _, config_path = get_model_and_dataset(
        config=config_path, loading_epoch=loading_epoch)

    # set savepath
    savepath = f'{model.model_dir}/plots/contaminations'
    if not os.path.isdir(savepath):
        os.makedirs(savepath)

    # load files (same segmentation pipeline for both)
    source = _load_chunks(source_path, hparams)
    contamination = _load_chunks(contamination_path, hparams)

    # Choose which samples to contaminate and by which degree
    contamination_indices = []
    contamination_degrees = []
    xs = []
    ys = []
    p_contamination = contamination_parameters['p_contamination']
    for index, chunk in enumerate(source['chunks']):
        if random.random() < p_contamination:
            contamination_indices.append(index)
            contamination_degrees.append(random.random())
            xs.append(chunk)
            # choose (randomly?) a contaminating syllable
            ys.append(random.choice(contamination['chunks']))

    # Guard: with no selected syllable, np.stack([]) would raise; just write
    # the source untouched.
    if not xs:
        print('No syllable selected for contamination; writing source as-is')
        sf.write(f'{savepath}/contamination.wav',
                 source['waveform'],
                 samplerate=hparams.sr)
        return

    xs_cuda = cuda_variable(torch.tensor(np.stack(xs)))
    ys_cuda = cuda_variable(torch.tensor(np.stack(ys)))

    # Encode both batches to latent space
    mu, logvar = model.encode(xs_cuda)
    x_z = model.reparameterize(mu, logvar)
    mu, logvar = model.encode(ys_cuda)
    y_z = model.reparameterize(mu, logvar)
    z_out = torch.zeros_like(x_z)

    # Contaminate: interpolate each latent by its own random degree t
    for batch_ind, t in enumerate(contamination_degrees):
        if method == 'linear':
            z_out[batch_ind] = x_z[batch_ind] * (1 - t) + y_z[batch_ind] * t
        elif method == 'constant_radius':
            z_out[batch_ind] = constant_radius_interpolation(
                x_z[batch_ind], y_z[batch_ind], t)
    # Decode z
    x_recon = model.decode(z_out).cpu().detach().numpy()

    # Replace contamined samples in original wave
    out_wave = source['waveform']
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    mel_inversion_basis = build_mel_inversion_basis(mel_basis)
    for batch_index, contamination_index in enumerate(contamination_indices):
        new_chunk = x_recon[batch_index, 0]
        new_audio = inv_spectrogram_sp(new_chunk,
                                       n_fft=hparams.n_fft,
                                       win_length=hparams.win_length_samples,
                                       hop_length=hparams.hop_length_samples,
                                       ref_level_db=hparams.ref_level_db,
                                       power=hparams.power,
                                       mel_inversion_basis=mel_inversion_basis)
        start_sample = source['start_samples'][contamination_index]
        end_sample = source['end_samples'][contamination_index]
        length_sample = end_sample - start_sample
        # TODO: apply a fade here (original note: "FAIRE UN FADE ICI")
        out_wave[start_sample:end_sample] = new_audio[:length_sample]
    sf.write(f'{savepath}/contamination.wav', out_wave, samplerate=hparams.sr)
コード例 #7
0
def main(debug, sr, num_mel_bins, n_fft, chunk_len, mel_lower_edge_hertz,
         mel_upper_edge_hertz, hop_length_ms, win_length_ms, ref_level_db,
         power):
    """Build a syllable dataset (pickled spectrograms) from a wav collection.

    Segmentation metadata comes from the dataset's JSON files; each syllable
    is preprocessed and saved as a dict pickle under
    ``DATA_DIR/syllables/<DATASET_ID>_<hparams repr>``. ``chunk_len`` is a
    dict ``{'type': 'ms'|'samples'|'stft_win', 'value': ...}`` and is rounded
    so chunks align exactly with STFT windows/hops.
    """
    # DATASET_ID = 'BIRD_DB_CATH'
    # DATASET_ID = 'Bird_all'
    # DATASET_ID = 'Test'
    # DATASET_ID = 'voizo_all'
    DATASET_ID = 'voizo_all_segmented'
    # DATASET_ID = 'voizo_chunks_test_segmented'
    # Syllable counters for which debug artifacts (wav/pdf) are dumped
    ind_examples = [200, 400, 500, 600, 800, 1000, 1200, 1400, 1600, 1800]

    # STFT time parameters (defaults: win = n_fft samples, hop = win / 4)
    if win_length_ms is None:
        win_length = n_fft
    else:
        win_length = ms_to_sample(win_length_ms, sr)
    if hop_length_ms is None:
        hop_length = win_length // 4
    else:
        hop_length = ms_to_sample(hop_length_ms, sr)

    ################################################################################
    # Round the requested chunk length to a whole number of STFT windows:
    # samples = (num_windows - 1) * hop + win
    if chunk_len['type'] == 'ms':
        chunk_len_ms = chunk_len['value']
        chunk_len_samples_not_rounded = ms_to_sample(chunk_len_ms, sr)
        chunk_len_win = round(
            (chunk_len_samples_not_rounded - win_length) / hop_length) + 1
        chunk_len_samples = (chunk_len_win - 1) * hop_length + win_length
    elif chunk_len['type'] == 'samples':
        chunk_len_samples_not_rounded = chunk_len['value']
        chunk_len_win = round(
            (chunk_len_samples_not_rounded - win_length) / hop_length) + 1
        chunk_len_samples = (chunk_len_win - 1) * hop_length + win_length
    elif chunk_len['type'] == 'stft_win':
        chunk_len_win = chunk_len['value']
        chunk_len_samples = (chunk_len_win - 1) * hop_length + win_length
    ################################################################################
    print('Chunk length is automatically set to match STFT windows/hop sizes')
    print(
        f'STFT win length: {win_length} samples, {1000 * win_length / sr} ms')
    print(f'STFT hop length: {hop_length} samples, {1000* hop_length / sr} ms')
    print(
        f'Chunk length: {chunk_len_samples} samples, {chunk_len_win} win, {chunk_len_samples * 1000 / sr} ms'
    )

    ################################################################################
    print('Create dataset')
    hparams = HParams(
        sr=sr,
        num_mel_bins=num_mel_bins,
        n_fft=n_fft,
        chunk_len_samples=chunk_len_samples,
        mel_lower_edge_hertz=mel_lower_edge_hertz,
        mel_upper_edge_hertz=mel_upper_edge_hertz,
        power=power,  # for spectral inversion
        butter_lowcut=mel_lower_edge_hertz,
        butter_highcut=mel_upper_edge_hertz,
        ref_level_db=ref_level_db,
        preemphasis=0.97,
        mask_spec=False,
        win_length_samples=win_length,
        hop_length_samples=hop_length,
        mask_spec_kwargs={
            "spec_thresh": 0.9,
            "offset": 1e-10
        },
        reduce_noise=True,
        noise_reduce_kwargs={
            "n_std_thresh": 2.0,
            "prop_decrease": 0.8
        },
        n_jobs=1,
        verbosity=1,
    )
    # The hparams repr acts as a unique suffix for output folders
    suffix = hparams.__repr__()

    if debug:
        # Debug artifacts go to a fresh dump folder (wiped if it exists)
        dump_folder = f'dump/{suffix}'
        if os.path.isdir(dump_folder):
            shutil.rmtree(dump_folder)
        os.makedirs(dump_folder)
    else:
        dump_folder = None

    dataset = DataSet(DATASET_ID, hparams=hparams)
    print(f'Number files: {len(dataset.data_files)}')

    ################################################################################
    print('Create a dataset based upon JSON')
    verbosity = 10
    # n_jobs=1: parallelism currently disabled, Parallel kept for the API
    with Parallel(n_jobs=1, verbose=verbosity) as parallel:
        syllable_dfs = parallel(
            delayed(create_label_df)(
                dataset.data_files[key].data,
                hparams=dataset.hparams,
                labels_to_retain=[],
                unit="syllables",
                dict_features_to_retain=['species'],
                key=key,
            ) for key in tqdm(dataset.data_files.keys()))
    syllable_df = pd.concat(syllable_dfs)

    ################################################################################
    print('Get audio for dataset')
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    mel_inversion_basis = build_mel_inversion_basis(mel_basis)
    counter = 0
    save_loc = DATA_DIR / 'syllables' / f'{DATASET_ID}_{suffix}'
    if os.path.isdir(save_loc):
        raise Exception('already exists')
    os.makedirs(save_loc)
    skipped_counter = 0
    for key in syllable_df.key.unique():
        # load audio (key.unique is for loading large wavfiles only once)
        this_syllable_df = syllable_df[syllable_df.key == key]
        wav_loc = dataset.data_files[key].data['wav_loc']
        print(f'{wav_loc}')
        data, _ = prepare_wav(wav_loc, hparams, debug=debug)
        data = data.astype('float32')
        # process each syllable
        for syll_ind, (st, et) in enumerate(
                zip(this_syllable_df.start_time.values,
                    this_syllable_df.end_time.values)):
            s = data[int(st * hparams.sr):int(et * hparams.sr)]
            sn, mS, debug_info = process_syllable(syl=s,
                                                  hparams=hparams,
                                                  mel_basis=mel_basis,
                                                  debug=debug)
            if sn is None:
                # syllable rejected by the preprocessing pipeline
                skipped_counter += 1
                continue
            # Save as uint to save space
            mS_int = (mS * 255).astype('uint8')
            # NOTE(review): `.indv[syll_ind]` / `.species[syll_ind]` index by
            # label, not position — relies on per-key index starting at 0;
            # verify against create_label_df.
            save_dict = {
                'mS_int': mS_int,
                'sn': sn,
                'indv': this_syllable_df.indv[syll_ind],
                'label': this_syllable_df.species[syll_ind]
            }
            fname = save_loc / str(counter)
            with open(fname, 'wb') as ff:
                pkl.dump(save_dict, ff)
            counter += 1

            if debug and (counter in ind_examples):
                # normalised audio
                sf.write(f'{dump_folder}/{counter}_sn.wav',
                         sn,
                         samplerate=hparams.sr)
                #  Padded mel db norm spectro
                plt.clf()
                plt.matshow(mS, origin="lower")
                plt.savefig(f'{dump_folder}/{counter}_mS.pdf')
                plt.close()
                audio_reconstruct = inv_spectrogram_sp(
                    mS,
                    n_fft=hparams.n_fft,
                    win_length=hparams.win_length_samples,
                    hop_length=hparams.hop_length_samples,
                    ref_level_db=hparams.ref_level_db,
                    power=hparams.power,
                    mel_inversion_basis=mel_inversion_basis)
                sf.write(f'{dump_folder}/{counter}_mS.wav',
                         audio_reconstruct,
                         samplerate=hparams.sr)

    print(f'Skipped counter: {skipped_counter}')
    #  Save hparams
    print("Save hparams")
    hparams_loc = f'{save_loc}_hparams.pkl'
    with open(hparams_loc, 'wb') as ff:
        pkl.dump(hparams, ff)
    # Print hparams
    print("Print hparams")
    hparams_loc = f'{save_loc}_hparams.txt'
    with open(hparams_loc, 'w') as ff:
        for k, v in hparams.__dict__.items():
            ff.write(f'{k}: {v}\n')
コード例 #8
0
def single_file_test(sr,
                     num_mel_bins,
                     n_fft,
                     mel_lower_edge_hertz,
                     mel_upper_edge_hertz,
                     hop_length_ms,
                     win_length_ms,
                     power,
                     ref_level_db,
                     wav_loc,
                     index=0):
    """Debug one wav file through every spectrogram stage.

    Writes, into a dump folder keyed by the hparams repr, the intermediate
    spectrogram images (pdf) and Griffin-Lim inversions (wav) for each stage:
    raw STFT magnitude, mel, mel_db and mel_db_norm. ``index`` prefixes the
    output filenames. Only the first 10 s of the file are processed.
    """
    duration_sec = 10

    hparams = HParams(
        sr=sr,
        num_mel_bins=num_mel_bins,
        n_fft=n_fft,
        mel_lower_edge_hertz=mel_lower_edge_hertz,
        mel_upper_edge_hertz=mel_upper_edge_hertz,
        power=power,  # for spectral inversion
        butter_lowcut=mel_lower_edge_hertz,
        butter_highcut=mel_upper_edge_hertz,
        ref_level_db=ref_level_db,
        mask_spec=False,
        win_length_ms=win_length_ms,
        hop_length_ms=hop_length_ms,
        mask_spec_kwargs={
            "spec_thresh": 0.9,
            "offset": 1e-10
        },
        n_jobs=1,
        verbosity=1,
        reduce_noise=True,
        noise_reduce_kwargs={
            "n_std_thresh": 2.0,
            "prop_decrease": 0.8
        },
    )
    suffix = hparams.__repr__()

    # FIX: `'dump' / f'{suffix}'` raised TypeError (str / str); build the
    # path with an f-string as the other scripts do.
    dump_folder = f'dump/{suffix}'
    if not os.path.isdir(dump_folder):
        os.makedirs(dump_folder)

    #  Read wave (truncate to the first duration_sec seconds)
    data, _ = prepare_wav(wav_loc, hparams, debug=True)
    data = data[:hparams.sr * duration_sec]

    # create spec
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    mel_inversion_basis = build_mel_inversion_basis(mel_basis)
    # mel_basis = None
    # mel_inversion_basis = None
    plt.clf()
    plt.matshow(mel_basis)
    plt.savefig(f'{dump_folder}/mel_basis.pdf')
    plt.close()
    plt.clf()
    plt.matshow(mel_inversion_basis)
    plt.savefig(f'{dump_folder}/mel_inversion_basis.pdf')
    plt.close()
    _, debug_info = spectrogram_sp(y=data,
                                   sr=hparams.sr,
                                   n_fft=hparams.n_fft,
                                   win_length_ms=hparams.win_length_ms,
                                   hop_length_ms=hparams.hop_length_ms,
                                   ref_level_db=hparams.ref_level_db,
                                   _mel_basis=mel_basis,
                                   pre_emphasis=hparams.preemphasis,
                                   power=hparams.power,
                                   debug=True)

    if debug_info is None:
        print('chunk too short')
        return

    from avgn.signalprocessing.spectrogramming_scipy import griffinlim_sp, _mel_to_linear

    # Raw (truncated) input
    sf.write(f'{dump_folder}/{index}_syllable.wav',
             data,
             samplerate=hparams.sr)

    #  preemphasis y
    sf.write(f'{dump_folder}/{index}_preemphasis_y.wav',
             debug_info['preemphasis_y'],
             samplerate=hparams.sr)

    # S_abs: linear STFT magnitude (first 200 frames shown)
    plt.clf()
    plt.matshow(debug_info['S_abs'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_S_abs.pdf')
    plt.close()
    S_abs_inv = griffinlim_sp(debug_info['S'], hparams.sr, hparams)
    sf.write(f'{dump_folder}/{index}_S_abs.wav',
             S_abs_inv,
             samplerate=hparams.sr)

    # mel: mel-projected magnitude; invert via mel -> linear -> Griffin-Lim
    plt.clf()
    plt.matshow(debug_info['mel'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_mel.pdf')
    plt.close()
    if mel_basis is not None:
        mel_inv = griffinlim_sp(
            _mel_to_linear(debug_info['mel'],
                           _mel_inverse_basis=mel_inversion_basis), hparams.sr,
            hparams)
        sf.write(f'{dump_folder}/{index}_mel.wav',
                 mel_inv,
                 samplerate=hparams.sr)

    # mel_db: undo the dB conversion manually before inversion
    plt.clf()
    plt.matshow(debug_info['mel_db'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_mel_db.pdf')
    plt.close()
    aa_ = debug_info['mel_db'] + hparams.ref_level_db
    bb_ = _db_to_amplitude(aa_)
    if mel_basis is not None:
        cc_ = _mel_to_linear(bb_, _mel_inverse_basis=mel_inversion_basis)
    else:
        cc_ = bb_
    mel_db_inv = griffinlim_sp(cc_, hparams.sr, hparams)
    sf.write(f'{dump_folder}/{index}_mel_db.wav',
             mel_db_inv,
             samplerate=hparams.sr)

    ##################################
    # Sanity check: re-analyse the mel_db inversion with an independent STFT
    ##################################
    f, t, S = signal.stft(x=mel_db_inv,
                          fs=hparams.sr,
                          window='hann',
                          nperseg=1024,
                          noverlap=256,
                          nfft=None,
                          detrend=False,
                          return_onesided=True,
                          boundary='zeros',
                          padded=True,
                          axis=-1)
    S_abs = np.abs(S)
    S_view = _amplitude_to_db(S_abs)
    plt.clf()
    plt.matshow(S_view)
    plt.savefig(f'{dump_folder}/{index}_mel_db_RECON.pdf')
    plt.close()
    ##################################
    ##################################

    # mel_db_norm: final normalised representation, full pipeline inversion
    plt.clf()
    plt.matshow(debug_info['mel_db_norm'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_mel_db_norm.pdf')
    plt.close()
    mel_db_norm_inv = inv_spectrogram_sp(
        debug_info['mel_db_norm'],
        hparams.sr,
        hparams,
        mel_inversion_basis=mel_inversion_basis)
    sf.write(f'{dump_folder}/{index}_mel_db_norm.wav',
             mel_db_norm_inv,
             samplerate=hparams.sr)
コード例 #9
0
def plot_reconstruction(model, hparams, dataloader, savepath, custom_data):
    """Plot originals above their model reconstructions and, when ``hparams``
    is given, invert both to audio files in ``savepath``.

    Inputs come from ``custom_data['all_data']`` when provided, otherwise
    from the first batch of ``dataloader``.
    """
    model.eval()
    if custom_data is not None:
        x_orig = custom_data['all_data']
        x_cuda = cuda_variable(torch.tensor(custom_data['all_data']))
        x_recon = model.reconstruct(x_cuda).cpu().detach().numpy()
    else:
        # Take only the first batch
        for _, data in enumerate(dataloader):
            x_orig = data['input'].numpy()
            x_cuda = cuda_variable(data['input'])
            x_recon = model.reconstruct(x_cuda).cpu().detach().numpy()
            break

    # Figure: row 0 = originals, row 1 = reconstructions
    dims = x_recon.shape[2:]
    num_examples = x_recon.shape[0]
    plt.clf()
    fig, axes = plt.subplots(nrows=2, ncols=num_examples)
    for col in range(num_examples):
        axes[0, col].matshow(x_orig[col].reshape(dims), origin="lower")
        axes[1, col].matshow(x_recon[col].reshape(dims), origin="lower")
    for axis in fig.get_axes():
        axis.set_xticks([])
        axis.set_yticks([])
    plt.savefig(f'{savepath}/spectro.pdf')
    plt.close('all')

    # Audio inversion (skipped when hparams is None)
    original_audios = []
    reconstruction_audios = []
    if hparams is not None:
        mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
        mel_inversion_basis = build_mel_inversion_basis(mel_basis)
        # Same inversion settings for originals and reconstructions
        inversion_kwargs = dict(
            n_fft=hparams.n_fft,
            win_length=hparams.win_length_samples,
            hop_length=hparams.hop_length_samples,
            ref_level_db=hparams.ref_level_db,
            power=hparams.power,
            mel_inversion_basis=mel_inversion_basis)
        for col in range(num_examples):
            original_audio = inv_spectrogram_sp(x_orig[col, 0],
                                                **inversion_kwargs)
            recon_audio = inv_spectrogram_sp(x_recon[col, 0],
                                             **inversion_kwargs)

            sf.write(f'{savepath}/{col}_original.wav',
                     original_audio,
                     samplerate=hparams.sr)
            sf.write(f'{savepath}/{col}_recon.wav',
                     recon_audio,
                     samplerate=hparams.sr)

            original_audios.append(original_audio)
            reconstruction_audios.append(recon_audio)
    return {
        'original_audios': original_audios,
        'reconstruction_audios': reconstruction_audios,
        'original_spectros': x_orig,
        'reconstruction_spectros': x_recon
    }