Пример #1
0
def main():
    """Build an STFT-based WGAN, restore checkpoint 'ckpt-100' and train it.

    Sets CUDA_VISIBLE_DEVICES to '0', loads the STFT dataset, creates the
    generator/discriminator pair with one Adam optimizer each, wires them
    into ``wgan.WGAN`` together with an example-saving callback, restores
    a checkpoint and starts training.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    visible_gpus = tf.config.experimental.list_physical_devices('GPU')
    print('Num GPUs Available: ', len(visible_gpus))

    raw_dataset = waveform_dataset.get_stft_dataset(
        DATASET_PATH,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )

    generator = spec_gan.Generator(channels=2, in_shape=Z_IN_SHAPE)
    discriminator = spec_gan.Discriminator(input_shape=SPECTOGRAM_IMAGE_SHAPE)

    # Both networks use the same Adam hyper-parameters.
    adam_betas = {'beta_1': 0.5, 'beta_2': 0.9}
    generator_optimizer = tf.keras.optimizers.Adam(1e-4, **adam_betas)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, **adam_betas)

    def get_waveform(stft):
        """Convert an STFT to a waveform, keeping only the first result."""
        converted = spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )
        return converted[0]

    def save_examples(epoch, real, generated):
        """Forward real/generated examples of an epoch to the wav saver."""
        return save_helper.save_wav_data(
            epoch, real, generated, SAMPLING_RATE, RESULT_DIR, get_waveform
        )

    stft_gan_model = wgan.WGAN(
        raw_dataset,
        generator,
        [discriminator],
        Z_DIM,
        generator_optimizer,
        discriminator_optimizer,
        discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=save_examples,
    )

    # NOTE(review): the meaning of the second argument (1000) is defined by
    # WGAN.restore, which is not visible here -- confirm before changing it.
    stft_gan_model.restore('ckpt-100', 1000)
    stft_gan_model.train()
Пример #2
0
def main():
    """Train a WGAN on normalized magnitude/phase spectrograms.

    Sets CUDA_VISIBLE_DEVICES to '3', loads the magnitude/phase dataset
    together with its per-channel statistics, normalizes every spectrogram
    with those statistics, builds the generator/discriminator pair with one
    Adam optimizer each and starts training from scratch.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    visible_gpus = tf.config.experimental.list_physical_devices('GPU')
    print('Num GPUs Available: ', len(visible_gpus))

    dataset_bundle = waveform_dataset.get_magnitude_phase_dataset(
        DATASET_PATH, FFT_FRAME_LENGTH, FFT_FRAME_STEP, LOG_MAGNITUDE,
        INSTANTANEOUS_FREQUENCY
    )
    raw_dataset, magnitude_stats, phase_stats = dataset_bundle

    def normalize_channels(spectogram):
        """Normalize channel 0 and channel 1 separately and re-stack them."""
        magnitude = waveform_dataset.normalize(
            spectogram[:, :, 0], *magnitude_stats
        )
        phase = waveform_dataset.normalize(
            spectogram[:, :, 1], *phase_stats
        )
        # Stacking on a new trailing axis rebuilds the two-channel layout.
        return np.stack([magnitude, phase], axis=-1)

    progress = utils.Progbar(len(raw_dataset))
    normalized_items = []
    for spectogram in raw_dataset:
        normalized_items.append(normalize_channels(spectogram))
        progress.add(1)
    normalized_raw_dataset = np.array(normalized_items)

    generator = spec_gan.Generator(
        channels=2, activation=activations.tanh, in_shape=Z_IN_SHAPE
    )
    discriminator = spec_gan.Discriminator(input_shape=SPECTOGRAM_IMAGE_SHAPE)

    # Both networks use the same Adam hyper-parameters.
    adam_betas = {'beta_1': 0.5, 'beta_2': 0.9}
    generator_optimizer = tf.keras.optimizers.Adam(1e-4, **adam_betas)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, **adam_betas)

    def get_waveform(spectogram):
        """Turn a normalized spectrogram back into a waveform."""
        return save_helper.get_waveform_from_normalized_spectogram(
            spectogram, [magnitude_stats, phase_stats], FFT_FRAME_LENGTH,
            FFT_FRAME_STEP, LOG_MAGNITUDE, INSTANTANEOUS_FREQUENCY
        )

    def save_examples(epoch, real, generated):
        """Forward real/generated examples of an epoch to the wav saver."""
        return save_helper.save_wav_data(
            epoch, real, generated, SAMPLING_RATE, RESULT_DIR, get_waveform
        )

    spec_phase_gan_model = wgan.WGAN(
        normalized_raw_dataset,
        generator,
        [discriminator],
        Z_DIM,
        generator_optimizer,
        discriminator_optimizer,
        discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=save_examples,
    )

    spec_phase_gan_model.train()
Пример #3
0
# that mode (or data point) and returns a waveform. Additionally, 'waveform': [] must
# be set; this is where the waveforms are collected.
MODELS = {
    'WaveGAN': {
        'generator': wave_gan.Generator(),
        'checkpoint_path':\
            '_results/representation_study/SpeechMNIST/WaveGAN/training_checkpoints/ckpt-30',
        'preprocess': {
            'unnormalize_magnitude': False,
            'unnormalize_spectogram': False,
        },
        'generate_fn': lambda x: x,
        'waveform': [],
    },
    'STFTGAN_HR': {
        'generator': spec_gan.Generator(channels=2, in_shape=[4, 8, 1024]),
        'checkpoint_path':\
            '_results/representation_study/SpeechMNIST/STFTGAN_HR/training_checkpoints/ckpt-30',
        'preprocess': {
            'unnormalize_magnitude': False,
            'unnormalize_spectogram': False,
        },
        'fft_config': 1,
        'generate_fn': lambda stfts: spectral.stft_2_waveform(
            stfts, FFT_FRAME_LENGTHS[1], FFT_FRAME_STEPS[1]
        ),
        'waveform': [],
    },
    'STFTWaveGAN_HR': {
        'generator': spec_gan.Generator(channels=2, in_shape=[4, 8, 1024]),
        'checkpoint_path':\