コード例 #1
0
ファイル: ddsp_run.py プロジェクト: zhufengGNSS/ddsp
def main(unused_argv):
    """Run ddsp training, evaluation, or sampling according to --mode.

    Directories come from --save_dir and --restore_dir (user-expanded). An
    empty --restore_dir means "use the save directory". The restore directory
    is created if needed and gin is set up before the selected mode runs.
    Modes other than 'train', 'eval' and 'sample' fall through and do nothing.
    """
    requested_restore_dir = os.path.expanduser(FLAGS.restore_dir)
    save_dir = os.path.expanduser(FLAGS.save_dir)
    # An empty restore directory falls back to the save directory.
    restore_dir = requested_restore_dir or save_dir
    logging.info('Restore Dir: %s', restore_dir)
    logging.info('Save Dir: %s', save_dir)

    # makedirs only creates the directory when it does not exist yet.
    gfile.makedirs(restore_dir)
    parse_gin(restore_dir)
    train_util.gin_register_keras_layers()

    if FLAGS.allow_memory_growth:
        allow_memory_growth()

    mode = FLAGS.mode
    if mode == 'train':
        strategy = train_util.get_strategy(
            tpu=FLAGS.tpu, cluster_config=FLAGS.cluster_config)
        # Model and trainer are constructed inside the strategy scope.
        with strategy.scope():
            model = models.get_model()
            trainer = trainers.Trainer(model, strategy)
        train_util.train(
            data_provider=gin.REQUIRED,
            trainer=trainer,
            save_dir=save_dir,
            restore_dir=restore_dir,
            early_stop_loss_value=FLAGS.early_stop_loss_value,
            report_loss_to_hypertune=FLAGS.hypertune,
        )
    elif mode in ('eval', 'sample'):
        # Evaluation and sampling share their setup and differ only in
        # which eval_util entry point is called.
        model = models.get_model()
        delay_start()
        run_eval_fn = eval_util.evaluate if mode == 'eval' else eval_util.sample
        run_eval_fn(
            data_provider=gin.REQUIRED,
            model=model,
            save_dir=save_dir,
            restore_dir=restore_dir,
            run_once=FLAGS.run_once,
        )
コード例 #2
0
ファイル: ddsp_run.py プロジェクト: tproost/ddsp
def main(unused_argv):
  """Run ddsp training, evaluation, or sampling according to --mode.

  Directories come from --save_dir and --restore_dir (user-expanded). An
  empty --restore_dir means "use the save directory". Gin is parsed from the
  restore directory before the selected mode runs. Modes other than 'train',
  'eval' and 'sample' fall through and do nothing.
  """
  requested_restore_dir = os.path.expanduser(FLAGS.restore_dir)
  save_dir = os.path.expanduser(FLAGS.save_dir)
  # An empty restore directory falls back to the save directory.
  restore_dir = requested_restore_dir or save_dir
  logging.info('Restore Dir: %s', restore_dir)
  logging.info('Save Dir: %s', save_dir)

  parse_gin(restore_dir)
  if FLAGS.allow_memory_growth:
    allow_memory_growth()

  mode = FLAGS.mode
  if mode == 'train':
    strategy = train_util.get_strategy(tpu=FLAGS.tpu, gpus=FLAGS.gpu)
    # Model and trainer are constructed inside the strategy scope.
    with strategy.scope():
      model = models.get_model()
      trainer = trainers.Trainer(model, strategy)
    train_util.train(
        data_provider=gin.REQUIRED,
        trainer=trainer,
        save_dir=save_dir,
        restore_dir=restore_dir,
    )
  elif mode in ('eval', 'sample'):
    # Evaluation and sampling share their setup and differ only in which
    # eval_util entry point is called.
    model = models.get_model()
    delay_start()
    run_eval_fn = eval_util.evaluate if mode == 'eval' else eval_util.sample
    run_eval_fn(
        data_provider=gin.REQUIRED,
        model=model,
        save_dir=save_dir,
        restore_dir=restore_dir,
        run_once=FLAGS.run_once,
    )
コード例 #3
0
                                                 name='processor_group')


# Loss functions: a single L1 spectral loss with its magnitude and
# log-magnitude terms weighted equally.
spectral_loss = ddsp.losses.SpectralLoss(
    loss_type='L1', mag_weight=1.0, logmag_weight=1.0)

with strategy.scope():
  # Assemble an encoder-less autoencoder from the preprocessor, decoder and
  # processor group, then create its trainer, all inside the strategy scope.
  model = models.Autoencoder(
      preprocessor=preprocessor,
      encoder=None,
      decoder=decoder,
      processor_group=processor_group,
      losses=[spectral_loss],
  )
  trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)

"""## or [`gin`](https://github.com/google/gin-config)"""

gin_string = """
import ddsp
import ddsp.training

# Preprocessor
models.Autoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor()
preprocessing.F0LoudnessPreprocessor.time_steps = 1000


# Encoder
models.Autoencoder.encoder = None
コード例 #4
0
def prepare_complete_tfrecord(dataset_dir='nsynth_guitar',
                              split='train',
                              sample_rate=16000,
                              frame_rate=250,
                              example_secs=4):
    """Write ``complete.tfrecord``: partial-record features plus encoder outputs.

    Reads ``<dataset_dir>/<split>/partial.tfrecord*`` one example at a time.
    It restores the checkpoint found in ``<dataset_dir>/<split>``, runs
    ``model.encode`` on each example, and writes the original features plus
    ``f0_scaled``, ``ld_scaled`` and the flattened ``z`` to
    ``<dataset_dir>/<split>/complete.tfrecord``. A progress line is printed
    per example.

    Args:
        dataset_dir: Root directory of the dataset.
        split: Sub-directory of ``dataset_dir`` holding the partial TFRecord
            shards, the operative gin config and the checkpoint to restore.
        sample_rate: Audio sample rate forwarded to the data provider.
        frame_rate: Feature frame rate forwarded to the data provider.
        example_secs: Example length in seconds forwarded to the data
            provider (previously hard-coded to 4, still the default).
    """
    split_dir = os.path.join(dataset_dir, split)
    operative_config_file = train_util.get_latest_operative_config(split_dir)
    partial_tfrecord_file = os.path.join(split_dir, 'partial.tfrecord')
    complete_tfrecord_file = os.path.join(split_dir, 'complete.tfrecord')

    # Load the gin config stored next to the checkpoint, presumably so that
    # models.get_model() below rebuilds the trained architecture.
    with gin.unlock_config():
        if tf.io.gfile.exists(operative_config_file):
            gin.parse_config_file(operative_config_file, skip_unknown=True)

    data_provider = PartialTFRecordProvider(
        file_pattern=partial_tfrecord_file + '*',
        example_secs=example_secs,
        sample_rate=sample_rate,
        frame_rate=frame_rate)

    # Batches of one example, unshuffled, single pass over the data.
    dataset = data_provider.get_batch(1, shuffle=False, repeats=1)

    # Features copied verbatim from the partial record, each paired with the
    # helper that turns its value into a tf.train.Feature.
    passthrough_features = (
        ('sample_name', _byte_feature),
        ('note_number', _int64_feature),
        ('velocity', _int64_feature),
        ('instrument_source', _int64_feature),
        ('qualities', _int64_feature),
        ('audio', _float_feature),
        ('f0_hz', _float_feature),
        ('f0_confidence', _float_feature),
        ('loudness_db', _float_feature),
    )

    strategy = train_util.get_strategy()
    with strategy.scope():
        model = models.get_model()
        trainer = trainers.Trainer(model, strategy)
        trainer.restore(split_dir)

        with tf.io.TFRecordWriter(complete_tfrecord_file) as writer:
            for step, features in enumerate(dataset):
                start_time = time.perf_counter()

                # Read the raw values BEFORE encoding. NOTE(review): encode()
                # may update `features` in place; reading first keeps the
                # values as stored in the partial record. [0] drops the batch
                # axis (batch size is 1).
                sample_name = features['sample_name'][0].numpy()
                complete_dataset_dict = {
                    key: serialize(features[key][0].numpy())
                    for key, serialize in passthrough_features
                }

                encoded = model.encode(features, training=False)
                complete_dataset_dict.update({
                    'f0_scaled': _float_feature(
                        tf.squeeze(encoded['f0_scaled']).numpy()),
                    'ld_scaled': _float_feature(
                        tf.squeeze(encoded['ld_scaled']).numpy()),
                    # Flatten z to 1-D. The shape must be the list [-1];
                    # `(-1)` is just the scalar -1, not a 1-D shape.
                    'z': _float_feature(
                        tf.reshape(encoded['z'], shape=[-1]).numpy()),
                })

                tf_example = tf.train.Example(
                    features=tf.train.Features(feature=complete_dataset_dict))
                writer.write(tf_example.SerializeToString())

                elapsed_time = time.perf_counter() - start_time
                # Log the plain sample name rather than the batched tensor
                # held in the encoded dict.
                if isinstance(sample_name, bytes):
                    sample_name = sample_name.decode('utf-8', 'replace')
                print('{} - sample_name: {} - elapsed_time: {:.3f}'.format(
                    step + 1, sample_name, elapsed_time))