def transfer(model_dir, timbre_files_pattern, melody_files_pattern, fine_tune=False):
    """Run timbre transfer for every (melody, timbre) file pair and log audio summaries.

    Loads the latest checkpoint from `model_dir` (via TunedAEFileApi when
    `fine_tune` is set, TrainedModelFileAPI otherwise) and writes one audio
    summary per transferred pair under `<model_dir>/summaries/examples`.
    """
    summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries', 'examples'))
    # The checkpoint filename encodes the global step as "ckpt-<step>".
    checkpoint = train_util.get_latest_chekpoint(model_dir)
    step = int(checkpoint.split("/")[-1].replace("ckpt-", ""))
    melody_files = glob(melody_files_pattern)
    timbre_files = glob(timbre_files_pattern)
    print("Melody files: ", melody_files)
    print("Timbre files: ", timbre_files)
    with summary_writer.as_default():
        api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        model = api_cls(model_dir)
        for melody_path in melody_files:
            for timbre_path in timbre_files:
                transferred = model.transfer(melody_path, timbre_path)
                summaries.audio_summary(transferred,
                                        step=step,
                                        audio_type="transfer",
                                        fine_tune=fine_tune,
                                        melody=melody_path,
                                        timbre=timbre_path,
                                        model=model_dir)
def reconstruct(model_dir, files_pattern, fine_tune=False):
    """Reconstruct every file matching `files_pattern` and log audio + error summaries.

    Uses TunedAEFileApi when `fine_tune` is set, TrainedModelFileAPI otherwise,
    and records the unskewed spectral loss of each reconstruction.
    """
    writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries', 'examples'))
    # The checkpoint filename encodes the global step as "ckpt-<step>".
    ckpt = train_util.get_latest_chekpoint(model_dir)
    step = int(ckpt.split("/")[-1].replace("ckpt-", ""))
    spectral_loss = UnskewedSpectralLoss(fft_sizes=(2048, 1024, 512, 256, 128, 64),
                                         loss_type='L1',
                                         log_smooth=[None, 1, 10, 100])
    with writer.as_default():
        api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        print(api_cls)
        model = api_cls(model_dir)
        for path in glob(files_pattern):
            original, recon = model.reconstruct(path)
            loss = spectral_loss(original, recon)
            # Both summaries share the same metadata keyword arguments.
            common = dict(audio_type="reconstruction",
                          step=step,
                          fine_tune=fine_tune,
                          model=model_dir,
                          audio_file=path,
                          reconstruction_loss=loss)
            summaries.audio_summary(recon, **common)
            summaries.error_heatmap_summary(original, recon, **common)
def sample(self, batch, outputs, step):
    """Write audio, error-heatmap, waveform and spectrogram summaries for one batch."""
    tag = self.tag
    original = batch['audio']
    generated = np.array(outputs['audio_gen'])
    # Audio players for the reconstruction and for the reference signal.
    summaries.audio_summary(generated, step=step, audio_type="reconstruction",
                            tag=tag, dataset=tag)
    summaries.audio_summary(original, step=step, audio_type="original",
                            tag=tag, dataset=tag)
    # Visual comparisons between original and generated audio.
    summaries.error_heatmap_summary(original, generated, step=step,
                                    name=f"{tag}/error_heatmap", dataset=tag)
    summaries.waveform_summary(original, generated, step=step,
                               name=f"{tag}/waveform", dataset=tag)
    summaries.spectrogram_summary(original, generated, step=step,
                                  name=f"{tag}/spectrogram", dataset=tag)
def sample(self, batch, outputs, step):
    """Write MIDI-autoencoder summaries (audio, f0, pianoroll, loudness) for one batch."""
    audio = batch['audio']
    sample_rate = self._sample_rate
    summaries.audio_summary(audio, step, sample_rate, name='audio_original')
    # Log every synthesized audio variant the model produced this step.
    for key in ('midi_audio', 'synth_audio', 'midi_audio2', 'synth_audio2'):
        if key in outputs and outputs[key] is not None:
            generated = outputs[key]
            summaries.audio_summary(generated, step, sample_rate, name=key)
            summaries.spectrogram_summary(audio, generated, step, tag=key)
            summaries.waveform_summary(audio, generated, step, name=key)
    summaries.f0_summary(batch[self._f0_key], outputs[f'{self._f0_key}_pred'],
                         step, name='f0_hz_rec')
    summaries.pianoroll_summary(outputs, step, 'pianoroll', self._frame_rate,
                                'pianoroll')
    summaries.midiae_f0_summary(batch[self._f0_key], outputs, step)
    # Loudness reconstruction is optional; only log it when present.
    if f'{self._db_key}_rec' in outputs:
        summaries.midiae_ld_summary(batch[self._db_key], outputs, step,
                                    self._db_key)
    summaries.midiae_sp_summary(outputs, step)
def sample(self, batch, outputs, step):
    """Write audio players and comparison plots for one sampled batch."""
    original = batch['audio']
    generated = np.array(outputs['audio_gen'])
    # Audio.
    summaries.audio_summary(generated, step, self._sample_rate,
                            name='audio_generated')
    summaries.audio_summary(original, step, self._sample_rate,
                            name='audio_original')
    # Plots.
    summaries.waveform_summary(original, generated, step)
    summaries.spectrogram_summary(original, generated, step)
def cycle_reconstruct(model_dir, timbre_files_pattern, melody_files_pattern,
                      fine_tune=False,
                      extra_melodies=('samples/guitar/AR_Lick4_FN.wav',)):
    """Cycle-reconstruct each timbre file through each melody and log results.

    Every file matching `timbre_files_pattern` is cycled through each
    intermediate melody and back; the unskewed spectral loss between the
    original and the cycled audio is logged together with audio and
    error-heatmap summaries under `<model_dir>/summaries/cycle`.

    Args:
      model_dir: Directory holding the trained model checkpoints.
      timbre_files_pattern: Glob pattern for the audio files to cycle.
      melody_files_pattern: Glob pattern for the intermediate melodies.
      fine_tune: If True, load the model through TunedAEFileApi, consistent
        with `transfer`/`reconstruct`/`interpolate`. (Previously the flag was
        only recorded in the summaries while the base TrainedModelFileAPI was
        always used — that looked like an omission.)
      extra_melodies: Extra melody paths always appended to the glob results
        (replaces the previously hard-coded sample file; default preserves the
        old behavior).
    """
    summary_dir = os.path.join(model_dir, 'summaries', 'cycle')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    # The checkpoint filename encodes the global step as "ckpt-<step>".
    latest_checkpoint = train_util.get_latest_chekpoint(model_dir)
    step = int(latest_checkpoint.split("/")[-1].replace("ckpt-", ""))
    unskewed_loss = UnskewedSpectralLoss(fft_sizes=(2048, 1024, 512, 256, 128, 64),
                                         loss_type='L1',
                                         log_smooth=[None, 1, 10, 100])
    with summary_writer.as_default():
        # Select the API class like the sibling entry points do.
        file_api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        model = file_api_cls(model_dir)
        for melody in glob(melody_files_pattern) + list(extra_melodies):
            for timbre in glob(timbre_files_pattern):
                original_audio, cycled = model.cycle_reconstruct(
                    audio=timbre, intermediate_melody=melody)
                loss = unskewed_loss(original_audio, cycled)
                summaries.audio_summary(cycled,
                                        audio_type="cycled",
                                        audio_file=timbre,
                                        intermediate=melody,
                                        step=step,
                                        cycle_reconstruction_loss=loss,
                                        model=model_dir,
                                        fine_tune=fine_tune)
                summaries.error_heatmap_summary(original_audio, cycled,
                                                audio_type="cycled",
                                                audio_file=timbre,
                                                intermediate=melody,
                                                step=step,
                                                cycle_reconstruction_loss=loss,
                                                model=model_dir,
                                                fine_tune=fine_tune)
def interpolate(model_dir, timbre_files_pattern, melody_files_pattern, fine_tune=False):
    """Run continuous timbre interpolation for every (melody, timbre) pair.

    Each melody file is used as the starting timbre and each timbre file as
    the target; the interpolated audio is logged as a summary under
    `<model_dir>/summaries/examples`.
    """
    writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries', 'examples'))
    # The checkpoint filename encodes the global step as "ckpt-<step>".
    ckpt = train_util.get_latest_chekpoint(model_dir)
    step = int(ckpt.split("/")[-1].replace("ckpt-", ""))
    with writer.as_default():
        api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        model = api_cls(model_dir)
        for melody_path in glob(melody_files_pattern):
            for timbre_path in glob(timbre_files_pattern):
                blended = model.continuous_interpolation(melody_path, timbre_path)
                summaries.audio_summary(blended,
                                        audio_type="interpolation",
                                        step=step,
                                        fine_tune=fine_tune,
                                        model=model_dir,
                                        melody=melody_path,
                                        timbre=timbre_path)
def evaluate_or_sample(data_provider,
                       model,
                       mode='eval',
                       save_dir='/tmp/ddsp/training',
                       restore_dir='',
                       batch_size=32,
                       num_batches=50,
                       ckpt_delay_secs=0,
                       run_once=False,
                       run_until_step=0):
    """Run evaluation loop.

    Watches `restore_dir` for checkpoints; for each one found, runs the model
    over `num_batches` batches, either writing TensorBoard summaries
    ('sample' mode) or accumulating and flushing metrics ('eval' mode).

    Args:
      data_provider: DataProvider instance.
      model: Model instance.
      mode: Whether to 'eval' with metrics or create 'sample' s.
      save_dir: Path to directory to save summary events.
      restore_dir: Path to directory with checkpoints, defaults to save_dir.
      batch_size: Size of each eval/sample batch.
      num_batches: How many batches to eval from dataset. -1 denotes all
        batches.
      ckpt_delay_secs: Time to wait when a new checkpoint was not detected.
      run_once: Only run evaluation or sampling once.
      run_until_step: Run until we see a checkpoint with a step greater or
        equal to the specified value. Ignored if <= 0.

    Returns:
      If the mode is 'eval', then returns a dictionary of Tensors keyed by
      loss type. Otherwise, returns None.
    """
    # Default to restoring from the save directory.
    restore_dir = save_dir if not restore_dir else restore_dir

    # Set up the summary writer and metrics.
    summary_dir = os.path.join(save_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)

    # Sample continuously and load the newest checkpoint each time
    checkpoints_iterator = tf.train.checkpoints_iterator(restore_dir,
                                                         ckpt_delay_secs)

    # Get the dataset.
    dataset = data_provider.get_batch(batch_size=batch_size,
                                      shuffle=False,
                                      repeats=-1)
    # Get audio sample rate
    sample_rate = data_provider.sample_rate
    # Get feature frame rate
    frame_rate = data_provider.frame_rate

    latest_losses = None

    with summary_writer.as_default():
        for checkpoint_path in checkpoints_iterator:
            # Checkpoint paths end in "...ckpt-<step>": recover the global step.
            step = int(checkpoint_path.split('-')[-1])

            # Redefine the dataset iterator each time to make deterministic.
            dataset_iter = iter(dataset)

            # Load model.
            model.restore(checkpoint_path)

            # Iterate through dataset and make predictions
            checkpoint_start_time = time.time()
            for batch_idx in range(1, num_batches + 1):
                try:
                    start_time = time.time()
                    logging.info('Predicting batch %d of size %d', batch_idx,
                                 batch_size)

                    # Predict a batch of audio.
                    batch = next(dataset_iter)
                    if isinstance(data_provider, data.SyntheticNotes):
                        # Synthetic notes carry no recorded audio: generate it,
                        # then derive confidence/loudness features from it.
                        batch['audio'] = model.generate_synthetic_audio(batch)
                        batch['f0_confidence'] = tf.ones_like(
                            batch['f0_hz'])[:, :, 0]
                        batch['loudness_db'] = ddsp.spectral_ops.compute_loudness(
                            batch['audio'])
                    elif isinstance(data_provider, data.ZippedProvider):
                        batch, unused_ss_batch = model.parse_zipped_features(
                            batch)

                    # TODO(jesseengel): Find a way to add losses with training=False.
                    audio = batch['audio']
                    audio_gen, losses = model(batch, return_losses=True,
                                              training=True)
                    outputs = model.get_controls(batch, training=True)

                    # Create metrics on first batch.
                    if mode == 'eval' and batch_idx == 1:
                        loudness_metrics = metrics.LoudnessMetrics(
                            sample_rate=sample_rate, frame_rate=frame_rate)
                        f0_metrics = metrics.F0Metrics(sample_rate=sample_rate,
                                                       frame_rate=frame_rate,
                                                       name='f0_harm')
                        f0_crepe_metrics = metrics.F0CrepeMetrics(
                            sample_rate=sample_rate, frame_rate=frame_rate)
                        f0_twm_metrics = metrics.F0Metrics(
                            sample_rate=sample_rate, frame_rate=frame_rate,
                            name='f0_twm')
                        # Running mean of each loss term across all batches.
                        avg_losses = {
                            name: tf.keras.metrics.Mean(name=name,
                                                        dtype=tf.float32)
                            for name in list(losses.keys())
                        }

                    processor_group = getattr(model, 'processor_group', None)
                    if processor_group is not None:
                        for processor in processor_group.processors:
                            # If using a sinusoidal model, infer f0 with two-way mismatch.
                            if isinstance(processor, ddsp.synths.Sinusoidal):
                                # Run on CPU to avoid running out of memory (not expensive).
                                with tf.device('CPU'):
                                    processor_controls = outputs[
                                        processor.name]['controls']
                                    amps = processor_controls['amplitudes']
                                    freqs = processor_controls['frequencies']
                                    twm = ddsp.losses.TWMLoss()
                                    # Treat all freqs as candidate f0s.
                                    outputs['f0_hz_twm'] = twm.predict_f0(
                                        freqs, freqs, amps)
                                logging.info(
                                    'Added f0 estimate from sinusoids.')
                                break
                            # If using a noisy sinusoidal model, infer f0 w/ two-way mismatch.
                            elif isinstance(processor,
                                            ddsp.synths.NoisySinusoidal):
                                # Run on CPU to avoid running out of memory (not expensive).
                                with tf.device('CPU'):
                                    processor_controls = outputs[
                                        processor.name]['controls']
                                    amps = processor_controls['amplitudes']
                                    freqs = processor_controls['frequencies']
                                    noise_ratios = processor_controls[
                                        'noise_ratios']
                                    # Discount sinusoid amplitudes by their noise share.
                                    amps = amps * (1.0 - noise_ratios)
                                    twm = ddsp.losses.TWMLoss()
                                    # Treat all freqs as candidate f0s.
                                    outputs['f0_hz_twm'] = twm.predict_f0(
                                        freqs, freqs, amps)
                                logging.info(
                                    'Added f0 estimate from sinusoids.')
                                break

                    has_f0_twm = ('f0_hz_twm' in outputs and 'f0_hz' in batch)
                    has_f0 = ('f0_hz' in outputs and 'f0_hz' in batch)

                    logging.info('Prediction took %.1f seconds',
                                 time.time() - start_time)

                    if mode == 'sample':
                        start_time = time.time()
                        logging.info('Writing summmaries for batch %d',
                                     batch_idx)

                        if audio_gen is not None:
                            audio_gen = np.array(audio_gen)
                            # Add audio.
                            summaries.audio_summary(audio_gen, step,
                                                    sample_rate,
                                                    name='audio_generated')
                            summaries.audio_summary(audio, step, sample_rate,
                                                    name='audio_original')
                            # Add plots.
                            summaries.waveform_summary(audio, audio_gen, step)
                            summaries.spectrogram_summary(
                                audio, audio_gen, step)
                            if has_f0:
                                summaries.f0_summary(batch['f0_hz'],
                                                     outputs['f0_hz'], step,
                                                     name='f0_harmonic')
                            if has_f0_twm:
                                summaries.f0_summary(batch['f0_hz'],
                                                     outputs['f0_hz_twm'],
                                                     step, name='f0_twm')

                        logging.info(
                            'Writing batch %i with size %i took %.1f seconds',
                            batch_idx, batch_size, time.time() - start_time)

                    elif mode == 'eval':
                        start_time = time.time()
                        logging.info('Calculating metrics for batch %d',
                                     batch_idx)

                        if audio_gen is not None:
                            loudness_metrics.update_state(batch, audio_gen)
                            if has_f0:
                                f0_metrics.update_state(
                                    batch, outputs['f0_hz'])
                            else:
                                # No harmonic f0 in outputs: fall back to CREPE
                                # estimation on the generated audio.
                                f0_crepe_metrics.update_state(batch, audio_gen)
                            if has_f0_twm:
                                f0_twm_metrics.update_state(
                                    batch, outputs['f0_hz_twm'])

                        # Loss.
                        for k, v in losses.items():
                            avg_losses[k].update_state(v)

                        logging.info(
                            'Metrics for batch %i with size %i took %.1f seconds',
                            batch_idx, batch_size, time.time() - start_time)

                except tf.errors.OutOfRangeError:
                    logging.info('End of dataset.')
                    break

            logging.info('All %d batches in checkpoint took %.1f seconds',
                         num_batches, time.time() - checkpoint_start_time)

            # NOTE(review): the metrics and `has_f0`/`has_f0_twm` names are
            # bound inside the try above; if the very first batch raised
            # OutOfRangeError they are undefined here — confirm this path
            # cannot occur with a non-empty dataset.
            if mode == 'eval':
                loudness_metrics.flush(step)
                if has_f0:
                    f0_metrics.flush(step)
                else:
                    f0_crepe_metrics.flush(step)
                if has_f0_twm:
                    f0_twm_metrics.flush(step)
                latest_losses = {}
                for k, metric in avg_losses.items():
                    latest_losses[k] = metric.result()
                    tf.summary.scalar('losses/{}'.format(k), metric.result(),
                                      step=step)
                    metric.reset_states()

            summary_writer.flush()

            if run_once:
                break
            if 0 < run_until_step <= step:
                logging.info(
                    'Saw checkpoint with step %d, which is greater or equal to'
                    ' `run_until_step` of %d. Exiting.', step, run_until_step)
                break

    return latest_losses
def evaluate_or_sample(data_provider,
                       model,
                       mode='eval',
                       save_dir='~/tmp/ddsp/training',
                       restore_dir='',
                       batch_size=32,
                       num_batches=50,
                       ckpt_delay_secs=0,
                       run_once=False,
                       run_until_step=0):
    """Run evaluation loop.

    Watches `restore_dir` for checkpoints; for each one found, runs the model
    over `num_batches` batches, either writing TensorBoard summaries
    ('sample' mode) or accumulating and flushing metrics ('eval' mode).

    Args:
      data_provider: DataProvider instance.
      model: Model instance.
      mode: Whether to 'eval' with metrics or create 'sample' s.
      save_dir: Path to directory to save summary events.
        NOTE(review): the default contains '~', which neither os.path.join
        nor the summary writer expands — this creates a literal "~" directory;
        confirm whether os.path.expanduser was intended.
      restore_dir: Path to directory with checkpoints, defaults to save_dir.
      batch_size: Size of each eval/sample batch.
      num_batches: How many batches to eval from dataset. -1 denotes all
        batches.
      ckpt_delay_secs: Time to wait when a new checkpoint was not detected.
      run_once: Only run evaluation or sampling once.
      run_until_step: Run until we see a checkpoint with a step greater or
        equal to the specified value. Ignored if <= 0.

    Returns:
      If the mode is 'eval', then returns a dictionary of Tensors keyed by
      loss type. Otherwise, returns None.
    """
    # Default to restoring from the save directory.
    restore_dir = save_dir if not restore_dir else restore_dir

    # Set up the summary writer and metrics.
    summary_dir = os.path.join(save_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)

    # Sample continuously and load the newest checkpoint each time
    checkpoints_iterator = tf.train.checkpoints_iterator(restore_dir,
                                                         ckpt_delay_secs)

    # Get the dataset.
    dataset = data_provider.get_batch(batch_size=batch_size,
                                      shuffle=False,
                                      repeats=-1)
    # Get audio sample rate
    sample_rate = data_provider.sample_rate
    # Get feature frame rate
    frame_rate = data_provider.frame_rate

    latest_losses = None

    with summary_writer.as_default():
        for checkpoint_path in checkpoints_iterator:
            # Checkpoint paths end in "...ckpt-<step>": recover the global step.
            step = int(checkpoint_path.split('-')[-1])

            # Redefine the dataset iterator each time to make deterministic.
            dataset_iter = iter(dataset)

            # Load model.
            model.restore(checkpoint_path)

            # Iterate through dataset and make predictions
            checkpoint_start_time = time.time()
            for batch_idx in range(1, num_batches + 1):
                try:
                    start_time = time.time()
                    logging.info('Predicting batch %d of size %d', batch_idx,
                                 batch_size)

                    # Predict a batch of audio.
                    batch = next(dataset_iter)

                    # TODO(jesseengel): Find a way to add losses with training=False.
                    audio = batch['audio']
                    audio_gen, losses = model(batch, return_losses=True,
                                              training=True)
                    audio_gen = np.array(audio_gen)
                    outputs = model.get_controls(batch, training=True)

                    # Create metrics on first batch.
                    if mode == 'eval' and batch_idx == 1:
                        loudness_metrics = metrics.LoudnessMetrics(
                            sample_rate=sample_rate, frame_rate=frame_rate)
                        f0_metrics = metrics.F0Metrics(sample_rate=sample_rate,
                                                       frame_rate=frame_rate,
                                                       name='f0_harm')
                        f0_crepe_metrics = metrics.F0CrepeMetrics(
                            sample_rate=sample_rate, frame_rate=frame_rate)
                        # Running mean of each loss term across all batches.
                        avg_losses = {
                            name: tf.keras.metrics.Mean(name=name,
                                                        dtype=tf.float32)
                            for name in list(losses.keys())
                        }

                    has_f0 = ('f0_hz' in outputs and 'f0_hz' in batch)

                    logging.info('Prediction took %.1f seconds',
                                 time.time() - start_time)

                    if mode == 'sample':
                        start_time = time.time()
                        logging.info('Writing summmaries for batch %d',
                                     batch_idx)

                        # Add audio.
                        summaries.audio_summary(audio_gen, step, sample_rate,
                                                name='audio_generated')
                        summaries.audio_summary(audio, step, sample_rate,
                                                name='audio_original')

                        # Add plots.
                        summaries.waveform_summary(audio, audio_gen, step)
                        summaries.spectrogram_summary(audio, audio_gen, step)
                        if has_f0:
                            summaries.f0_summary(batch['f0_hz'],
                                                 outputs['f0_hz'], step,
                                                 name='f0_harmonic')

                        logging.info(
                            'Writing batch %i with size %i took %.1f seconds',
                            batch_idx, batch_size, time.time() - start_time)

                    elif mode == 'eval':
                        start_time = time.time()
                        logging.info('Calculating metrics for batch %d',
                                     batch_idx)

                        loudness_metrics.update_state(batch, audio_gen)
                        if has_f0:
                            f0_metrics.update_state(batch, outputs['f0_hz'])
                        else:
                            # No harmonic f0 in outputs: fall back to CREPE
                            # estimation on the generated audio.
                            f0_crepe_metrics.update_state(batch, audio_gen)

                        # Loss.
                        for k, v in losses.items():
                            avg_losses[k].update_state(v)

                        logging.info(
                            'Metrics for batch %i with size %i took %.1f seconds',
                            batch_idx, batch_size, time.time() - start_time)

                except tf.errors.OutOfRangeError:
                    logging.info('End of dataset.')
                    break

            logging.info('All %d batches in checkpoint took %.1f seconds',
                         num_batches, time.time() - checkpoint_start_time)

            # NOTE(review): the metrics and `has_f0` names are bound inside the
            # try above; if the very first batch raised OutOfRangeError they
            # are undefined here — confirm this path cannot occur with a
            # non-empty dataset.
            if mode == 'eval':
                loudness_metrics.flush(step)
                if has_f0:
                    f0_metrics.flush(step)
                else:
                    f0_crepe_metrics.flush(step)
                latest_losses = {}
                for k, metric in avg_losses.items():
                    latest_losses[k] = metric.result()
                    tf.summary.scalar('losses/{}'.format(k), metric.result(),
                                      step=step)
                    metric.reset_states()

            summary_writer.flush()

            if run_once:
                break
            if 0 < run_until_step <= step:
                logging.info(
                    'Saw checkpoint with step %d, which is greater or equal to'
                    ' `run_until_step` of %d. Exiting.', step, run_until_step)
                break

    return latest_losses