def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print('Num GPUs Available: ', len(tf.config.experimental.list_physical_devices('GPU')))

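    # Load the training data as STFT frames; FFT_FRAME_LENGTH / FFT_FRAME_STEP set the analysis resolution.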
    raw_dataset = waveform_dataset.get_stft_dataset(
        DATASET_PATH, frame_length=FFT_FRAME_LENGTH, frame_step=FFT_FRAME_STEP
    )

    generator = spec_gan.Generator(channels=2, in_shape=Z_IN_SHAPE)
    discriminator = spec_gan.Discriminator(input_shape=SPECTOGRAM_IMAGE_SHAPE)

    generator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)

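    # Helpers for saving audio examples: invert the STFT back to a waveform, then write example wav files during training.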
    get_waveform = lambda stft:\
        spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )[0]

    save_examples = lambda epoch, real, generated:\
        save_helper.save_wav_data(
            epoch, real, generated, SAMPLING_RATE, RESULT_DIR, get_waveform
        )

    stft_gan_model = wgan.WGAN(
        raw_dataset, generator, [discriminator], Z_DIM,
        generator_optimizer, discriminator_optimizer, discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE, epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=save_examples
    )

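    # Resume from an existing checkpoint before continuing training
    # (the second argument to restore() is presumably the epoch counter to resume from).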
    stft_gan_model.restore('ckpt-100', 1000)
    stft_gan_model.train()
Example #2
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    print("Num GPUs Available: ",
          len(tf.config.experimental.list_physical_devices('GPU')))

    raw_maestro = waveform_dataset.get_stft_dataset(
        MAESTRO_PATH, frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP).astype(np.float32)
    raw_maestro_conditioning = waveform_dataset.get_waveform_dataset(
        MAESTRO_MIDI_PATH).astype(np.float32)

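    # Generator plus two critics: one on the full STFT representation, one on the magnitude spectrogram.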
    generator = midi_conditional_spec_gan.Generator(channels=2)
    discriminator = midi_conditional_spec_gan.Discriminator(
        input_shape=STFT_IMAGE_SHAPE)
    spec_discriminator = midi_conditional_spec_gan.Discriminator(
        input_shape=MAGNITUDE_IMAGE_SHAPE)

    generator_optimizer = tf.keras.optimizers.Adam(1e-4,
                                                   beta_1=0.5,
                                                   beta_2=0.9)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4,
                                                       beta_1=0.5,
                                                       beta_2=0.9)

    get_waveform = lambda stft:\
        spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )[0]

    save_examples = lambda epoch, real, generated:\
        save_helper.save_wav_data(
            epoch, real, generated, SAMPLING_RATE, RESULT_DIR, get_waveform
        )

    stft_gan_model = conditional_wgan.ConditionalWGAN(
        raw_maestro,
        raw_maestro_conditioning,
        generator, [discriminator, spec_discriminator],
        Z_DIM,
        generator_optimizer,
        discriminator_optimizer,
        discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=save_examples,
        fn_get_discriminator_input_representations=
        _get_discriminator_input_representations)

    stft_gan_model.train()
def _get_discriminator_input_representations(stft_in):
    """Computes the input representations for the STFTWaveGAN discriminator models,
    returning the input waveform and coresponding spectogram representations

    Args:
        x_in: A batch of stft with shape (-1, time, frequency_dims).

    Returns:
        A tuple containing the waveform and spectogram representaions of
        x_in.
    """

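    # Invert the STFT, drop singleton dimensions, and trim the waveform to the expected signal length.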
    waveform = spectral.stft_2_waveform(
        stft_in, frame_length=FFT_FRAME_LENGTH, frame_step=FFT_FRAME_STEP
    )
    waveform = tf.squeeze(waveform)
    waveform = waveform[:, 0:SIGNAL_LENGTH]

    return (stft_in, waveform)
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = ''
    print("Num GPUs Available: ",
          len(tf.config.experimental.list_physical_devices('GPU')))

    raw_maestro_conditioning = waveform_dataset.get_stft_dataset(
        MAESTRO_PATH, frame_length=FFT_FRAME_LENGTH, frame_step=FFT_FRAME_STEP)

    generator = ls_conditional_spec_gan.Generator(channels=2,
                                                  in_shape=(4, 8, 512))

    checkpoint_path = '_results/conditioning/LSC_STFTMagGAN_HR_8192/training_checkpoints/ckpt-11'

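    # Restore only the generator weights; expect_partial() silences warnings about
    # checkpoint values (e.g. optimizer/critic state) that are intentionally not restored here.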
    checkpoint = tf.train.Checkpoint(generator=generator)
    checkpoint.restore(checkpoint_path).expect_partial()

    get_waveform = lambda stft:\
        spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )[0]
    get_stft = lambda waveform:\
        spectral.waveform_2_stft(
            waveform, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )[0]

    seed = np.expand_dims(raw_maestro_conditioning[5], 0)

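    # Chain generations: each new chunk is conditioned on the STFT of the tail of the previous output.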
    sequence = []
    for _ in range(N_GENERATIONS):
        z_in = tf.random.uniform((1, 64), -1, 1)
        gen = generator(seed, z_in)
        wav = get_waveform(gen)[0:GENERATION_LENGTH]

        sequence.append(wav)
        wav_cond = wav[CONDITIONING_START_INDEX:]
        seed = np.expand_dims(get_stft(wav_cond), 0)

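    # Stitch the generated chunks together and write a single 16 kHz wav file.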
    audio = np.array(sequence)
    audio = np.squeeze(audio)
    audio = np.reshape(audio, (-1))
    sf.write('stftmaggan_babble.wav', audio, 16000)
Example #5
def main():
    raw_maestro = waveform_dataset.get_stft_dataset(
        MAESTRO_PATH, frame_length=FFT_FRAME_LENGTH, frame_step=FFT_FRAME_STEP
    )
    raw_maestro_conditioning = waveform_dataset.get_stft_dataset(
        MAESTRO_CONDITIONING_PATH, frame_length=FFT_FRAME_LENGTH, frame_step=FFT_FRAME_STEP
    )

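    # Generator with two critics: one on the spectrogram (STFT) representation, one on the magnitude
    # representation; their losses are presumably weighted via the CRITIC_WEIGHTINGS passed as lambdas below.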
    generator = ls_conditional_spec_gan.Generator(channels=2)
    stft_discriminator = ls_conditional_spec_gan.Discriminator(input_shape=SPECTOGRAM_IMAGE_SHAPE)
    mag_discriminator = ls_conditional_spec_gan.Discriminator(input_shape=MAGNITUDE_IMAGE_SHAPE)

    generator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)

    get_waveform = lambda stft:\
        spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )[0]

    save_examples = lambda epoch, real, generated:\
        save_helper.save_wav_data(
            epoch, real, generated, SAMPLING_RATE, RESULT_DIR, get_waveform
        )

    stft_mag_gan_model = conditional_wgan.ConditionalWGAN(
        (raw_maestro, raw_maestro_conditioning), [SPECTOGRAM_IMAGE_SHAPE, MAGNITUDE_IMAGE_SHAPE],
        [(-1, 64, 256, 2), (-1, 64, 256, 1)], generator,
        [stft_discriminator, mag_discriminator], Z_DIM, generator_optimizer,
        discriminator_optimizer, discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE, epochs=EPOCHS, lambdas=CRITIC_WEIGHTINGS,
        checkpoint_dir=CHECKPOINT_DIR, fn_save_examples=save_examples,
        fn_get_discriminator_input_representations=_get_discriminator_input_representations
    )

    stft_mag_gan_model.train()
Example #6
         'unnormalize_spectogram': False,
     },
     'generate_fn': lambda x: x,
     'waveform': [],
 },
 'STFTGAN_HR': {
     'generator': spec_gan.Generator(channels=2, in_shape=[4, 8, 1024]),
     'checkpoint_path':\
         '_results/representation_study/SpeechMNIST/STFTGAN_HR/training_checkpoints/ckpt-30',
     'preprocess': {
         'unnormalize_magnitude': False,
         'unnormalize_spectogram': False,
     },
     'fft_config': 1,
     'generate_fn': lambda stfts: spectral.stft_2_waveform(
         stfts, FFT_FRAME_LENGTHS[1], FFT_FRAME_STEPS[1]
     ),
     'waveform': [],
 },
 'STFTWaveGAN_HR': {
     'generator': spec_gan.Generator(channels=2, in_shape=[4, 8, 1024]),
     'checkpoint_path':\
         '_results/representation_study/SpeechMNIST/STFTWaveGAN_HR/training_checkpoints/ckpt-30',
     'preprocess': {
         'unnormalize_magnitude': False,
         'unnormalize_spectogram': False,
     },
     'fft_config': 1,
     'generate_fn': lambda stfts: spectral.stft_2_waveform(
         stfts, FFT_FRAME_LENGTHS[1], FFT_FRAME_STEPS[1]
     ),
     'waveform': [],
 },
     .stft_2_waveform(representation, window_length, window_step)[0],
     'distort_representation':
     distortion_helper.distort_multiple_channel_representation,
 },
 '(mel) STFT': {
     'requires_fft_params':
     True,
     'waveform_2_representation':
     lambda waveform, window_length,
     window_step, n_mel_bins: spectral.waveform_2_stft(
         waveform, window_length, window_step, n_mel_bins *
         MEL_BIN_MULTIPLIER, MEL_LOWER_HERTZ_EDGE, MEL_UPPER_HERTZ_EDGE)[0],
     'representation_2_waveform':
     lambda representation, window_length,
     window_step, n_mel_bins: spectral.stft_2_waveform(
         representation, window_length, window_step, n_mel_bins *
         MEL_BIN_MULTIPLIER, MEL_LOWER_HERTZ_EDGE, MEL_UPPER_HERTZ_EDGE)[0],
     'distort_representation':
     distortion_helper.distort_multiple_channel_representation,
 },
 'Mag': {
     'requires_fft_params':
     True,
     'waveform_2_representation':
     lambda waveform, window_length, window_step, n_mel_bins: spectral.
     waveform_2_magnitude(
         waveform, window_length, window_step, log_magnitude=LOG_MAGNITUDE)[0],
     'representation_2_waveform':
     lambda representation, window_length, window_step, n_mel_bins: spectral
     .magnitude_2_waveform(representation,