def main():
    """Resume training of the unconditional STFT WGAN.

    Makes only CUDA device 0 visible, loads the STFT dataset from
    DATASET_PATH, builds a spec_gan generator and a single critic with
    their Adam optimizers, restores the 'ckpt-100' checkpoint and then
    runs the training loop.
    """
    # Expose only the first CUDA device to TensorFlow.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    visible_gpus = tf.config.experimental.list_physical_devices('GPU')
    print('Num GPUs Available: ', len(visible_gpus))

    stft_data = waveform_dataset.get_stft_dataset(
        DATASET_PATH,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )

    generator = spec_gan.Generator(channels=2, in_shape=Z_IN_SHAPE)
    critic = spec_gan.Discriminator(input_shape=SPECTOGRAM_IMAGE_SHAPE)

    g_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    d_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)

    def stft_to_waveform(stft):
        # Invert the STFT and keep element 0 of the result.
        waveforms = spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )
        return waveforms[0]

    def write_examples(epoch, real, generated):
        # Callback handed to the trainer for saving audio examples.
        return save_helper.save_wav_data(
            epoch, real, generated,
            SAMPLING_RATE, RESULT_DIR, stft_to_waveform,
        )

    trainer = wgan.WGAN(
        stft_data,
        generator,
        [critic],
        Z_DIM,
        g_optimizer,
        d_optimizer,
        discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=write_examples,
    )

    # NOTE(review): the meaning of the second argument (1000) is defined
    # by WGAN.restore, which is not visible here -- confirm before changing.
    trainer.restore('ckpt-100', 1000)
    trainer.train()
def main():
    """Train the MIDI-conditional WGAN with two critics.

    Makes only CUDA device 1 visible, loads the MAESTRO STFT dataset and
    its conditioning waveforms (both cast to float32), builds the
    generator plus two critics (one sized by STFT_IMAGE_SHAPE, one by
    MAGNITUDE_IMAGE_SHAPE) and starts training.
    """
    # Expose only the second CUDA device to TensorFlow.
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    visible_gpus = tf.config.experimental.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(visible_gpus))

    maestro_stfts = waveform_dataset.get_stft_dataset(
        MAESTRO_PATH,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )
    maestro_stfts = maestro_stfts.astype(np.float32)

    midi_conditioning = waveform_dataset.get_waveform_dataset(
        MAESTRO_MIDI_PATH
    )
    midi_conditioning = midi_conditioning.astype(np.float32)

    generator = midi_conditional_spec_gan.Generator(channels=2)
    stft_critic = midi_conditional_spec_gan.Discriminator(
        input_shape=STFT_IMAGE_SHAPE
    )
    magnitude_critic = midi_conditional_spec_gan.Discriminator(
        input_shape=MAGNITUDE_IMAGE_SHAPE
    )

    g_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    d_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)

    def stft_to_waveform(stft):
        # Invert the STFT and keep element 0 of the result.
        waveforms = spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )
        return waveforms[0]

    def write_examples(epoch, real, generated):
        # Callback handed to the trainer for saving audio examples.
        return save_helper.save_wav_data(
            epoch, real, generated,
            SAMPLING_RATE, RESULT_DIR, stft_to_waveform,
        )

    trainer = conditional_wgan.ConditionalWGAN(
        maestro_stfts,
        midi_conditioning,
        generator,
        [stft_critic, magnitude_critic],
        Z_DIM,
        g_optimizer,
        d_optimizer,
        discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=write_examples,
        fn_get_discriminator_input_representations=(
            _get_discriminator_input_representations
        ),
    )
    trainer.train()
def _get_discriminator_input_representations(stft_in):
    """Computes the input representations for the STFTWaveGAN
    discriminator models: the STFT batch itself and the waveform
    obtained by inverting it.

    Args:
        stft_in: A batch of STFTs, batch dimension first. The exact
            layout expected is defined by spectral.stft_2_waveform.

    Returns:
        A tuple (stft_in, waveform), in that order: the unmodified
        input, followed by its waveform with size-1 dimensions removed
        and truncated to the first SIGNAL_LENGTH entries along axis 1.
    """
    inverted = spectral.stft_2_waveform(
        stft_in,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )
    # NOTE(review): tf.squeeze without an axis drops *every* size-1
    # dimension. If a batch of one ever reaches this function the batch
    # axis would presumably be removed too and the two-axis slice below
    # would fail -- confirm against stft_2_waveform's output shape.
    squeezed = tf.squeeze(inverted)
    return (stft_in, squeezed[:, 0:SIGNAL_LENGTH])
def main():
    """Generate a long audio clip by chaining conditional generations.

    Hides all CUDA devices, restores a generator from a fixed checkpoint,
    seeds it with entry 5 of the MAESTRO STFT dataset and then runs
    N_GENERATIONS steps. Each step generates a waveform segment, and the
    tail of that segment (from CONDITIONING_START_INDEX on) is converted
    back to an STFT to condition the next step. The concatenated segments
    are written to 'stftmaggan_babble.wav' at 16000 Hz.
    """
    # An empty device list hides every GPU from CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = ''
    visible_gpus = tf.config.experimental.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(visible_gpus))

    conditioning_stfts = waveform_dataset.get_stft_dataset(
        MAESTRO_PATH,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )

    generator = ls_conditional_spec_gan.Generator(
        channels=2, in_shape=(4, 8, 512)
    )

    # Only the generator is tracked, so the rest of the checkpoint is
    # left unmatched on purpose (expect_partial silences those warnings).
    ckpt = tf.train.Checkpoint(generator=generator)
    ckpt.restore(
        '_results/conditioning/LSC_STFTMagGAN_HR_8192/training_checkpoints/ckpt-11'
    ).expect_partial()

    def stft_to_waveform(stft):
        # Invert the STFT and keep element 0 of the result.
        waveforms = spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )
        return waveforms[0]

    def waveform_to_stft(waveform):
        # Forward STFT, again keeping element 0 of the result.
        stfts = spectral.waveform_2_stft(
            waveform, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )
        return stfts[0]

    # Conditioning input for the first step, with a leading axis added.
    seed = np.expand_dims(conditioning_stfts[5], 0)

    segments = []
    for _ in range(N_GENERATIONS):
        noise = tf.random.uniform((1, 64), -1, 1)
        generated_stft = generator(seed, noise)
        segment = stft_to_waveform(generated_stft)[0:GENERATION_LENGTH]
        segments.append(segment)

        # The tail of this segment conditions the next generation.
        tail = segment[CONDITIONING_START_INDEX:]
        seed = np.expand_dims(waveform_to_stft(tail), 0)

    # Stack the segments, drop size-1 axes and flatten to one signal.
    audio = np.reshape(np.squeeze(np.array(segments)), (-1))
    sf.write('stftmaggan_babble.wav', audio, 16000)
def main():
    """Train the conditional STFT + magnitude WGAN on MAESTRO.

    Loads the target and conditioning STFT datasets, builds the
    generator and two critics (one sized by SPECTOGRAM_IMAGE_SHAPE, one
    by MAGNITUDE_IMAGE_SHAPE) with their Adam optimizers, and starts
    training with the critic losses weighted by CRITIC_WEIGHTINGS.
    """
    target_stfts = waveform_dataset.get_stft_dataset(
        MAESTRO_PATH,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )
    conditioning_stfts = waveform_dataset.get_stft_dataset(
        MAESTRO_CONDITIONING_PATH,
        frame_length=FFT_FRAME_LENGTH,
        frame_step=FFT_FRAME_STEP,
    )

    generator = ls_conditional_spec_gan.Generator(channels=2)
    stft_critic = ls_conditional_spec_gan.Discriminator(
        input_shape=SPECTOGRAM_IMAGE_SHAPE
    )
    magnitude_critic = ls_conditional_spec_gan.Discriminator(
        input_shape=MAGNITUDE_IMAGE_SHAPE
    )

    g_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)
    d_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, beta_2=0.9)

    def stft_to_waveform(stft):
        # Invert the STFT and keep element 0 of the result.
        waveforms = spectral.stft_2_waveform(
            stft, FFT_FRAME_LENGTH, FFT_FRAME_STEP
        )
        return waveforms[0]

    def write_examples(epoch, real, generated):
        # Callback handed to the trainer for saving audio examples.
        return save_helper.save_wav_data(
            epoch, real, generated,
            SAMPLING_RATE, RESULT_DIR, stft_to_waveform,
        )

    # One entry per critic, in the same order as the critic list below.
    critic_input_shapes = [SPECTOGRAM_IMAGE_SHAPE, MAGNITUDE_IMAGE_SHAPE]
    # NOTE(review): presumably the reshape targets for each critic's
    # input -- confirm against ConditionalWGAN's signature.
    critic_batch_shapes = [(-1, 64, 256, 2), (-1, 64, 256, 1)]

    trainer = conditional_wgan.ConditionalWGAN(
        (target_stfts, conditioning_stfts),
        critic_input_shapes,
        critic_batch_shapes,
        generator,
        [stft_critic, magnitude_critic],
        Z_DIM,
        g_optimizer,
        d_optimizer,
        discriminator_training_ratio=D_UPDATES_PER_G,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        lambdas=CRITIC_WEIGHTINGS,
        checkpoint_dir=CHECKPOINT_DIR,
        fn_save_examples=write_examples,
        fn_get_discriminator_input_representations=(
            _get_discriminator_input_representations
        ),
    )
    trainer.train()
'unnormalize_spectogram': False, }, 'generate_fn': lambda x: x, 'waveform': [], }, 'STFTGAN_HR': { 'generator': spec_gan.Generator(channels=2, in_shape=[4, 8, 1024]), 'checkpoint_path':\ '_results/representation_study/SpeechMNIST/STFTGAN_HR/training_checkpoints/ckpt-30', 'preprocess': { 'unnormalize_magnitude': False, 'unnormalize_spectogram': False, }, 'fft_config': 1, 'generate_fn': lambda stfts: spectral.stft_2_waveform( stfts, FFT_FRAME_LENGTHS[1], FFT_FRAME_STEPS[1] ), 'waveform': [], }, 'STFTWaveGAN_HR': { 'generator': spec_gan.Generator(channels=2, in_shape=[4, 8, 1024]), 'checkpoint_path':\ '_results/representation_study/SpeechMNIST/STFTWaveGAN_HR/training_checkpoints/ckpt-30', 'preprocess': { 'unnormalize_magnitude': False, 'unnormalize_spectogram': False, }, 'fft_config': 1, 'generate_fn': lambda stfts: spectral.stft_2_waveform( stfts, FFT_FRAME_LENGTHS[1], FFT_FRAME_STEPS[1] ),
.stft_2_waveform(representation, window_length, window_step)[0], 'distort_representation': distortion_helper.distort_multiple_channel_representation, }, '(mel) STFT': { 'requires_fft_params': True, 'waveform_2_representation': lambda waveform, window_length, window_step, n_mel_bins: spectral.waveform_2_stft( waveform, window_length, window_step, n_mel_bins * MEL_BIN_MULTIPLIER, MEL_LOWER_HERTZ_EDGE, MEL_UPPER_HERTZ_EDGE)[0], 'representation_2_waveform': lambda representation, window_length, window_step, n_mel_bins: spectral.stft_2_waveform( representation, window_length, window_step, n_mel_bins * MEL_BIN_MULTIPLIER, MEL_LOWER_HERTZ_EDGE, MEL_UPPER_HERTZ_EDGE)[0], 'distort_representation': distortion_helper.distort_multiple_channel_representation, }, 'Mag': { 'requires_fft_params': True, 'waveform_2_representation': lambda waveform, window_length, window_step, n_mel_bins: spectral. waveform_2_magnitude( waveform, window_length, window_step, log_magnitude=LOG_MAGNITUDE)[ 0], 'representation_2_waveform': lambda representation, window_length, window_step, n_mel_bins: spectral .magnitude_2_waveform(representation,