Example #1
0
def reconstruct(model_dir, files_pattern, fine_tune=False):
    """Reconstruct every audio file matching ``files_pattern`` with the model.

    For each file, writes an audio summary of the reconstruction and an
    error-heatmap summary against the original, both tagged with the
    unskewed spectral reconstruction loss.
    """
    writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries', 'examples'))
    checkpoint = train_util.get_latest_chekpoint(model_dir)
    # Checkpoint paths end in "ckpt-<step>"; recover the step for summaries.
    step = int(checkpoint.split("/")[-1].replace("ckpt-", ""))
    loss_fn = UnskewedSpectralLoss(
        fft_sizes=(2048, 1024, 512, 256, 128, 64),
        loss_type='L1',
        log_smooth=[None, 1, 10, 100])
    with writer.as_default():
        api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        print(api_cls)
        model = api_cls(model_dir)
        for audio_file in glob(files_pattern):
            original, reconstruction = model.reconstruct(audio_file)
            loss = loss_fn(original, reconstruction)
            # Both summaries share the same tag set.
            tags = dict(audio_type="reconstruction",
                        step=step,
                        fine_tune=fine_tune,
                        model=model_dir,
                        audio_file=audio_file,
                        reconstruction_loss=loss)
            summaries.audio_summary(reconstruction, **tags)
            summaries.error_heatmap_summary(original, reconstruction, **tags)
Example #2
0
def transfer(model_dir,
             timbre_files_pattern,
             melody_files_pattern,
             fine_tune=False):
    """Perform timbre transfer for every (melody, timbre) file pair.

    Each melody file is re-synthesised with the timbre of each timbre file
    and the result is logged as a TensorBoard audio summary.
    """
    writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries', 'examples'))
    checkpoint = train_util.get_latest_chekpoint(model_dir)
    # Checkpoint paths end in "ckpt-<step>"; recover the step for summaries.
    step = int(checkpoint.split("/")[-1].replace("ckpt-", ""))
    melody_files = glob(melody_files_pattern)
    timbre_files = glob(timbre_files_pattern)
    print("Melody files: ", melody_files)
    print("Timbre files: ", timbre_files)
    with writer.as_default():
        api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        model = api_cls(model_dir)
        for melody in melody_files:
            for timbre in timbre_files:
                audio = model.transfer(melody, timbre)
                summaries.audio_summary(audio,
                                        step=step,
                                        audio_type="transfer",
                                        fine_tune=fine_tune,
                                        melody=melody,
                                        timbre=timbre,
                                        model=model_dir)
Example #3
0
 def restore(self, checkpoint_path):
   """Restore model and optimizer from a checkpoint."""
   t_start = time.time()
   ckpt = train_util.get_latest_chekpoint(checkpoint_path)
   if ckpt is None:
     # Nothing to restore; leave the model freshly initialised.
     logging.info('Could not find checkpoint to load at %s, skipping.',
                  checkpoint_path)
     return
   # expect_partial() silences warnings about optimizer slots that are
   # not present in this checkpoint wrapper.
   tf.train.Checkpoint(model=self).restore(ckpt).expect_partial()
   logging.info('Loaded checkpoint %s', ckpt)
   logging.info('Loading model took %.1f seconds', time.time() - t_start)
Example #4
0
def cycle_reconstruct(model_dir,
                      timbre_files_pattern,
                      melody_files_pattern,
                      fine_tune=False,
                      extra_melodies=('samples/guitar/AR_Lick4_FN.wav',)):
    """Cycle-reconstruct each timbre file through each intermediate melody.

    Every timbre file is transferred onto an intermediate melody and then
    back; the cycled audio plus an error heatmap against the original are
    logged, tagged with the unskewed spectral loss.

    Args:
      model_dir: Directory holding the trained model checkpoints.
      timbre_files_pattern: Glob pattern for the timbre (source) audio files.
      melody_files_pattern: Glob pattern for the intermediate melody files.
      fine_tune: If True, load the model through TunedAEFileApi instead of
        TrainedModelFileAPI, matching reconstruct()/transfer()/interpolate().
      extra_melodies: Extra melody paths always appended to the glob results.
        Defaults to the sample file that was previously hard-coded here, so
        existing callers see identical behavior.
    """
    summary_dir = os.path.join(model_dir, 'summaries', 'cycle')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    latest_checkpoint = train_util.get_latest_chekpoint(model_dir)
    # Checkpoint paths end in "ckpt-<step>"; recover the step for summaries.
    step = int(latest_checkpoint.split("/")[-1].replace("ckpt-", ""))
    unskewed_loss = UnskewedSpectralLoss(
        fft_sizes=(2048, 1024, 512, 256, 128, 64),
        loss_type='L1',
        log_smooth=[None, 1, 10, 100])
    with summary_writer.as_default():
        # Fix: previously this always used TrainedModelFileAPI, silently
        # ignoring fine_tune for model loading even though the flag was
        # recorded in the summaries below.
        file_api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
        model = file_api_cls(model_dir)
        melodies = glob(melody_files_pattern) + list(extra_melodies)
        for melody in melodies:
            for timbre in glob(timbre_files_pattern):
                original_audio, cycled = model.cycle_reconstruct(
                    audio=timbre, intermediate_melody=melody)
                loss = unskewed_loss(original_audio, cycled)
                # Both summaries share the same tag set.
                tags = dict(audio_type="cycled",
                            audio_file=timbre,
                            intermediate=melody,
                            step=step,
                            cycle_reconstruction_loss=loss,
                            model=model_dir,
                            fine_tune=fine_tune)
                summaries.audio_summary(cycled, **tags)
                summaries.error_heatmap_summary(original_audio, cycled, **tags)
Example #5
0
    def restore(self, checkpoint_path, restore_keys=None):
        """Restore model and optimizer from the latest checkpoint, if any."""
        logging.info('Restoring from checkpoint...')
        t_start = time.time()

        # The explicit argument takes precedence over the trainer attribute.
        keys = restore_keys or self.restore_keys
        if keys is None:
            # No keys requested: restore the entire model graph.
            target = self.model
            logging.info('Trainer restoring the full model')
        else:
            # Build a checkpoint over just the requested sub-modules.
            subcomponents = {key: getattr(self.model, key) for key in keys}
            target = tf.train.Checkpoint(**subcomponents)

            logging.info('Trainer restoring model subcomponents:')
            for key, module in subcomponents.items():
                logging.info('Restoring {}: {}'.format(key, module))

        checkpoint = self.get_checkpoint(target)
        latest = train_util.get_latest_chekpoint(checkpoint_path)
        if latest is None:
            logging.info('No checkpoint, skipping.')
            return
        # checkpoint.restore must run inside strategy.scope() so that
        # optimizer slot variables are mirrored across replicas.
        with self.strategy.scope():
            status = checkpoint.restore(latest)
            if keys is not None:
                # Partial restore is expected when only sub-modules load.
                status.expect_partial()
            logging.info('Loaded checkpoint %s', latest)
        logging.info('Loading model took %.1f seconds',
                     time.time() - t_start)
Example #6
0
def interpolate(model_dir,
                timbre_files_pattern,
                melody_files_pattern,
                fine_tune=False):
    """Generate continuous timbre interpolations for melody/timbre pairs.

    Each melody file is the starting timbre and each timbre file the final
    timbre; the interpolated audio is logged as a TensorBoard summary.
    """
    writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries', 'examples'))
    checkpoint = train_util.get_latest_chekpoint(model_dir)
    # Checkpoint paths end in "ckpt-<step>"; recover the step for summaries.
    step = int(checkpoint.split("/")[-1].replace("ckpt-", ""))
    api_cls = TunedAEFileApi if fine_tune else TrainedModelFileAPI
    with writer.as_default():
        model = api_cls(model_dir)
        for melody in glob(melody_files_pattern):
            for timbre in glob(timbre_files_pattern):
                audio = model.continuous_interpolation(melody, timbre)
                summaries.audio_summary(audio,
                                        audio_type="interpolation",
                                        step=step,
                                        fine_tune=fine_tune,
                                        model=model_dir,
                                        melody=melody,
                                        timbre=timbre)