def __init__(self, sample_rate, frame_rate, run_f0_crepe=True):
  super().__init__(sample_rate, frame_rate)
  self._loudness_metrics = metrics.LoudnessMetrics(
      sample_rate=sample_rate, frame_rate=frame_rate)
  self._f0_metrics = metrics.F0Metrics(
      sample_rate=sample_rate, frame_rate=frame_rate)
  self._run_f0_crepe = run_f0_crepe
  if self._run_f0_crepe:
    self._f0_crepe_metrics = metrics.F0CrepeMetrics(
        sample_rate=sample_rate, frame_rate=frame_rate)
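# NOTE: A hedged sketch, not part of the original class: `update_state` and
# `flush` companions that delegate to the sub-metrics built in __init__ above.
# The delegate signatures (`update_state(batch, audio_gen)`, `flush(step)`)
# mirror how these metrics are called in evaluate_or_sample() below; the
# `f0_hz` keyword argument here is an assumption.
def update_state(self, batch, audio_gen, f0_hz=None):
  """Accumulates metrics for one batch of generated audio."""
  self._loudness_metrics.update_state(batch, audio_gen)
  if f0_hz is not None:
    # Use the model's own f0 estimate when one is available.
    self._f0_metrics.update_state(batch, f0_hz)
  elif self._run_f0_crepe:
    # Otherwise fall back to extracting f0 from the audio with CREPE.
    self._f0_crepe_metrics.update_state(batch, audio_gen)

def flush(self, step):
  """Writes accumulated metrics to summaries and resets their state."""
  self._loudness_metrics.flush(step)
  self._f0_metrics.flush(step)
  if self._run_f0_crepe:
    self._f0_crepe_metrics.flush(step)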
def test_loudness_metrics_has_expected_values(self):
  loudness_metrics = ddsp_metrics.LoudnessMetrics(
      self.sample_rate, self.frame_rate)
  # Dummy batch 1: known noise features vs. known noise audio.
  loudness_metrics.update_state(self.batch_of_noise_feats, self.batch_of_noise)
  self.assertAllClose(loudness_metrics.metrics['loudness_db'].result(), 0)
  # Dummy batch 2: known noise features vs. quiet batch of sine audio.
  loudness_metrics.update_state(self.batch_of_noise_feats, self.batch_of_sin)
  self.assertGreater(loudness_metrics.metrics['loudness_db'].result(), 0)
  loudness_metrics.flush(step=1)
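# NOTE: A hedged sketch of the fixtures the test above assumes; the original
# setUp is not shown, so the exact construction here is an assumption. The
# feature keys ('audio', 'loudness_db') and
# ddsp.spectral_ops.compute_loudness() match their usage elsewhere in this
# file.
def setUp(self):
  super().setUp()
  self.sample_rate = 16000
  self.frame_rate = 250
  duration_secs = 1.0
  n_samples = int(self.sample_rate * duration_secs)
  # A batch of uniform noise audio, and a much quieter batch of sine audio.
  self.batch_of_noise = np.random.uniform(
      -0.5, 0.5, size=(2, n_samples)).astype(np.float32)
  t = np.linspace(0.0, duration_secs, n_samples)
  sin = 1e-3 * np.sin(2.0 * np.pi * 440.0 * t)
  self.batch_of_sin = np.stack([sin, sin]).astype(np.float32)
  # "Ground truth" features are always computed from the noise audio, so the
  # loudness error is ~0 for the noise batch and > 0 for the quiet sine batch.
  self.batch_of_noise_feats = {
      'audio': self.batch_of_noise,
      'loudness_db': ddsp.spectral_ops.compute_loudness(
          self.batch_of_noise, self.sample_rate, self.frame_rate),
  }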
def evaluate_or_sample(data_provider,
                       model,
                       mode='eval',
                       save_dir='/tmp/ddsp/training',
                       restore_dir='',
                       batch_size=32,
                       num_batches=50,
                       ckpt_delay_secs=0,
                       run_once=False,
                       run_until_step=0):
  """Run evaluation loop.

  Args:
    data_provider: DataProvider instance.
    model: Model instance.
    mode: Whether to 'eval' with metrics or create 'sample's.
    save_dir: Path to directory to save summary events.
    restore_dir: Path to directory with checkpoints, defaults to save_dir.
    batch_size: Size of each eval/sample batch.
    num_batches: How many batches to eval from dataset. -1 denotes all batches.
    ckpt_delay_secs: Time to wait when a new checkpoint was not detected.
    run_once: Only run evaluation or sampling once.
    run_until_step: Run until we see a checkpoint with a step greater or equal
      to the specified value. Ignored if <= 0.

  Returns:
    If the mode is 'eval', returns a dictionary of Tensors keyed by loss type.
    Otherwise, returns None.
  """
  # Default to restoring from the save directory.
  restore_dir = save_dir if not restore_dir else restore_dir

  # Set up the summary writer and metrics.
  summary_dir = os.path.join(save_dir, 'summaries', 'eval')
  summary_writer = tf.summary.create_file_writer(summary_dir)

  # Sample continuously and load the newest checkpoint each time.
  checkpoints_iterator = tf.train.checkpoints_iterator(restore_dir,
                                                       ckpt_delay_secs)

  # Get the dataset.
  dataset = data_provider.get_batch(batch_size=batch_size,
                                    shuffle=False,
                                    repeats=-1)
  # Get audio sample rate.
  sample_rate = data_provider.sample_rate
  # Get feature frame rate.
  frame_rate = data_provider.frame_rate

  latest_losses = None

  with summary_writer.as_default():
    for checkpoint_path in checkpoints_iterator:
      step = int(checkpoint_path.split('-')[-1])

      # Recreate the dataset iterator for each checkpoint to make evaluation
      # deterministic.
      dataset_iter = iter(dataset)

      # Load model.
      model.restore(checkpoint_path)

      # Iterate through dataset and make predictions.
      checkpoint_start_time = time.time()

      for batch_idx in range(1, num_batches + 1):
        try:
          start_time = time.time()
          logging.info('Predicting batch %d of size %d', batch_idx, batch_size)

          # Predict a batch of audio.
          batch = next(dataset_iter)

          if isinstance(data_provider, data.SyntheticNotes):
            batch['audio'] = model.generate_synthetic_audio(batch)
            batch['f0_confidence'] = tf.ones_like(batch['f0_hz'])[:, :, 0]
            batch['loudness_db'] = ddsp.spectral_ops.compute_loudness(
                batch['audio'])
          elif isinstance(data_provider, data.ZippedProvider):
            batch, unused_ss_batch = model.parse_zipped_features(batch)

          # TODO(jesseengel): Find a way to add losses with training=False.
          audio = batch['audio']
          audio_gen, losses = model(batch, return_losses=True, training=True)
          outputs = model.get_controls(batch, training=True)

          # Create metrics on first batch.
          if mode == 'eval' and batch_idx == 1:
            loudness_metrics = metrics.LoudnessMetrics(
                sample_rate=sample_rate, frame_rate=frame_rate)
            f0_metrics = metrics.F0Metrics(
                sample_rate=sample_rate, frame_rate=frame_rate, name='f0_harm')
            f0_crepe_metrics = metrics.F0CrepeMetrics(
                sample_rate=sample_rate, frame_rate=frame_rate)
            f0_twm_metrics = metrics.F0Metrics(
                sample_rate=sample_rate, frame_rate=frame_rate, name='f0_twm')
            avg_losses = {
                name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
                for name in list(losses.keys())
            }

          processor_group = getattr(model, 'processor_group', None)
          if processor_group is not None:
            for processor in processor_group.processors:
              # If using a sinusoidal model, infer f0 with two-way mismatch.
              if isinstance(processor, ddsp.synths.Sinusoidal):
                # Run on CPU to avoid running out of memory (not expensive).
                with tf.device('CPU'):
                  processor_controls = outputs[processor.name]['controls']
                  amps = processor_controls['amplitudes']
                  freqs = processor_controls['frequencies']
                  twm = ddsp.losses.TWMLoss()
                  # Treat all freqs as candidate f0s.
                  outputs['f0_hz_twm'] = twm.predict_f0(freqs, freqs, amps)
                  logging.info('Added f0 estimate from sinusoids.')
                break

              # If using a noisy sinusoidal model, infer f0 with two-way
              # mismatch.
              elif isinstance(processor, ddsp.synths.NoisySinusoidal):
                # Run on CPU to avoid running out of memory (not expensive).
                with tf.device('CPU'):
                  processor_controls = outputs[processor.name]['controls']
                  amps = processor_controls['amplitudes']
                  freqs = processor_controls['frequencies']
                  noise_ratios = processor_controls['noise_ratios']
                  # Scale amplitudes by the sinusoidal (non-noise) fraction.
                  amps = amps * (1.0 - noise_ratios)
                  twm = ddsp.losses.TWMLoss()
                  # Treat all freqs as candidate f0s.
                  outputs['f0_hz_twm'] = twm.predict_f0(freqs, freqs, amps)
                  logging.info('Added f0 estimate from sinusoids.')
                break

          has_f0_twm = ('f0_hz_twm' in outputs and 'f0_hz' in batch)
          has_f0 = ('f0_hz' in outputs and 'f0_hz' in batch)

          logging.info('Prediction took %.1f seconds',
                       time.time() - start_time)

          if mode == 'sample':
            start_time = time.time()
            logging.info('Writing summaries for batch %d', batch_idx)

            if audio_gen is not None:
              audio_gen = np.array(audio_gen)
              # Add audio.
              summaries.audio_summary(
                  audio_gen, step, sample_rate, name='audio_generated')
              summaries.audio_summary(
                  audio, step, sample_rate, name='audio_original')
              # Add plots.
              summaries.waveform_summary(audio, audio_gen, step)
              summaries.spectrogram_summary(audio, audio_gen, step)

            if has_f0:
              summaries.f0_summary(
                  batch['f0_hz'], outputs['f0_hz'], step, name='f0_harmonic')
            if has_f0_twm:
              summaries.f0_summary(
                  batch['f0_hz'], outputs['f0_hz_twm'], step, name='f0_twm')

            logging.info('Writing batch %i with size %i took %.1f seconds',
                         batch_idx, batch_size, time.time() - start_time)

          elif mode == 'eval':
            start_time = time.time()
            logging.info('Calculating metrics for batch %d', batch_idx)

            if audio_gen is not None:
              loudness_metrics.update_state(batch, audio_gen)
              if has_f0:
                f0_metrics.update_state(batch, outputs['f0_hz'])
              else:
                f0_crepe_metrics.update_state(batch, audio_gen)
            if has_f0_twm:
              f0_twm_metrics.update_state(batch, outputs['f0_hz_twm'])

            # Loss.
            for k, v in losses.items():
              avg_losses[k].update_state(v)

            logging.info('Metrics for batch %i with size %i took %.1f seconds',
                         batch_idx, batch_size, time.time() - start_time)

        except tf.errors.OutOfRangeError:
          logging.info('End of dataset.')
          break

      logging.info('All %d batches in checkpoint took %.1f seconds',
                   num_batches, time.time() - checkpoint_start_time)

      if mode == 'eval':
        loudness_metrics.flush(step)
        if has_f0:
          f0_metrics.flush(step)
        else:
          f0_crepe_metrics.flush(step)
        if has_f0_twm:
          f0_twm_metrics.flush(step)
        latest_losses = {}
        for k, metric in avg_losses.items():
          latest_losses[k] = metric.result()
          tf.summary.scalar('losses/{}'.format(k), metric.result(), step=step)
          metric.reset_states()

      summary_writer.flush()

      if run_once:
        break

      if 0 < run_until_step <= step:
        logging.info(
            'Saw checkpoint with step %d, which is greater or equal to'
            ' `run_until_step` of %d. Exiting.', step, run_until_step)
        break

  return latest_losses
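# NOTE: A hedged usage sketch, not from the original file: a one-shot eval of
# the latest checkpoint. `data.TFRecordProvider` and `models.Autoencoder` are
# standard ddsp.training classes (assuming `from ddsp.training import models`
# and a gin config that sets up the model); the paths here are hypothetical
# placeholders.
if __name__ == '__main__':
  data_provider = data.TFRecordProvider('/path/to/train.tfrecord*')
  model = models.Autoencoder()
  losses = evaluate_or_sample(
      data_provider=data_provider,
      model=model,
      mode='eval',
      save_dir='/tmp/ddsp/training',
      batch_size=8,
      num_batches=10,
      run_once=True)  # Evaluate the latest checkpoint once, then exit.
  if losses is not None:
    for name, value in losses.items():
      print(name, float(value))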