Example #1
def main(unused_argv):
    """Parse gin config and run ddsp training, evaluation, or sampling."""
    model_dir = os.path.expanduser(FLAGS.model_dir)
    parse_gin(model_dir)

    # Training.
    if FLAGS.mode == 'train':
        strategy = train_util.get_strategy(tpu=FLAGS.tpu, gpus=FLAGS.gpu)
        with strategy.scope():
            model = models.get_model()
            trainer = train_util.Trainer(model, strategy)

        train_util.train(data_provider=gin.REQUIRED,
                         trainer=trainer,
                         model_dir=model_dir)

    # Evaluation.
    elif FLAGS.mode == 'eval':
        model = models.get_model()
        delay_start()
        eval_util.evaluate(data_provider=gin.REQUIRED,
                           model=model,
                           model_dir=model_dir)

    # Sampling.
    elif FLAGS.mode == 'sample':
        model = models.get_model()
        delay_start()
        eval_util.sample(data_provider=gin.REQUIRED,
                         model=model,
                         model_dir=model_dir)
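
These snippets show only the main function; in the DDSP training scripts the entry point is normally wired up through absl. A minimal launcher sketch, assuming the module defines its flags (e.g. --mode, --model_dir) with absl.flags elsewhere:

# Sketch only: runs main() via absl, parsing the command-line flags used above.
from absl import app

if __name__ == '__main__':
    app.run(main)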
Example #2
def main(unused_argv):
    """Parse gin config and run ddsp training, evaluation, or sampling."""
    restore_dir = os.path.expanduser(FLAGS.restore_dir)
    save_dir = os.path.expanduser(FLAGS.save_dir)
    # If no separate restore directory is given, use the save directory.
    restore_dir = save_dir if not restore_dir else restore_dir
    logging.info('Restore Dir: %s', restore_dir)
    logging.info('Save Dir: %s', save_dir)

    gfile.makedirs(restore_dir)  # Only makes dirs if they don't exist.
    parse_gin(restore_dir)
    logging.info('Operative Gin Config:\n%s', gin.config.config_str())
    train_util.gin_register_keras_layers()

    if FLAGS.allow_memory_growth:
        allow_memory_growth()

    # Training.
    if FLAGS.mode == 'train':
        strategy = train_util.get_strategy(tpu=FLAGS.tpu,
                                           cluster_config=FLAGS.cluster_config)
        with strategy.scope():
            model = models.get_model()
            trainer = trainers.get_trainer_class()(model, strategy)

        train_util.train(data_provider=gin.REQUIRED,
                         trainer=trainer,
                         save_dir=save_dir,
                         restore_dir=restore_dir,
                         early_stop_loss_value=FLAGS.early_stop_loss_value,
                         report_loss_to_hypertune=FLAGS.hypertune,
                         stop_at_nan=FLAGS.stop_at_nan)

    # Evaluation.
    elif FLAGS.mode == 'eval':
        model = models.get_model()
        delay_start()
        eval_util.evaluate(data_provider=gin.REQUIRED,
                           model=model,
                           save_dir=save_dir,
                           restore_dir=restore_dir,
                           run_once=FLAGS.run_once)

    # Sampling.
    elif FLAGS.mode == 'sample':
        model = models.get_model()
        delay_start()
        eval_util.sample(data_provider=gin.REQUIRED,
                         model=model,
                         save_dir=save_dir,
                         restore_dir=restore_dir,
                         run_once=FLAGS.run_once)
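
The absl flags referenced above (mode, save_dir, restore_dir, run_once, and so on) are defined elsewhere in the script and are not part of this snippet. A rough sketch of how a few of them might be declared; the defaults and help strings below are assumptions, not the project's actual values:

# Hypothetical flag declarations matching the names used in the snippet.
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'train', ['train', 'eval', 'sample'], 'Which job to run.')
flags.DEFINE_string('save_dir', '/tmp/ddsp', 'Directory for checkpoints and summaries.')
flags.DEFINE_string('restore_dir', '', 'Directory to restore from; falls back to save_dir.')
flags.DEFINE_boolean('run_once', False, 'Evaluate or sample a single checkpoint, then exit.')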
Example #3
def run():
    """Parse gin config and run ddsp training, evaluation, or sampling."""
    model_dir = os.path.expanduser(FLAGS.model_dir)
    parse_gin(model_dir)
    model = models.get_model()

    # Training.
    if FLAGS.mode == 'train':
        train_util.train(data_provider=gin.REQUIRED,
                         model=model,
                         model_dir=model_dir,
                         num_steps=FLAGS.num_train_steps,
                         master=FLAGS.master,
                         use_tpu=FLAGS.use_tpu)

    # Evaluation.
    elif FLAGS.mode == 'eval':
        delay_start()
        eval_util.evaluate(data_provider=gin.REQUIRED,
                           model=model,
                           model_dir=model_dir,
                           master=FLAGS.master,
                           run_once=FLAGS.eval_once)

    # Sampling.
    elif FLAGS.mode == 'sample':
        delay_start()
        eval_util.sample(data_provider=gin.REQUIRED,
                         model=model,
                         model_dir=model_dir,
                         master=FLAGS.master,
                         run_once=FLAGS.eval_once)
Example #4
def main(unused_argv):
    """Parse gin config and run ddsp training, evaluation, or sampling."""
    restore_dir = os.path.expanduser(FLAGS.restore_dir)
    save_dir = os.path.expanduser(FLAGS.save_dir)
    # If no separate restore directory is given, use the save directory.
    restore_dir = save_dir if not restore_dir else restore_dir
    logging.info('Restore Dir: %s', restore_dir)
    logging.info('Save Dir: %s', save_dir)

    gfile.makedirs(restore_dir)  # Only makes dirs if they don't exist.
    parse_gin(restore_dir)

    if FLAGS.allow_memory_growth:
        allow_memory_growth()

    # Training.
    if FLAGS.mode == 'train':
        strategy = train_util.get_strategy(tpu=FLAGS.tpu, gpus=FLAGS.gpu)
        with strategy.scope():
            model = models.get_model()
            trainer = trainers.Trainer(model, strategy)

        train_util.train(data_provider=gin.REQUIRED,
                         trainer=trainer,
                         save_dir=save_dir,
                         restore_dir=restore_dir)

    # Evaluation.
    elif FLAGS.mode == 'eval':
        model = models.get_model()
        delay_start()
        eval_util.evaluate(data_provider=gin.REQUIRED,
                           model=model,
                           save_dir=save_dir,
                           restore_dir=restore_dir,
                           run_once=FLAGS.run_once)

    # Sampling.
    elif FLAGS.mode == 'sample':
        model = models.get_model()
        delay_start()
        eval_util.sample(data_provider=gin.REQUIRED,
                         model=model,
                         save_dir=save_dir,
                         restore_dir=restore_dir,
                         run_once=FLAGS.run_once)
Example #5
    def load_and_forward(self, audio):
        """Returns the features dict of forwarding the audio through the model

    Arguments:
      audio (filepath, dict): either a filepath or a feature dict with preprocessed audio

    Returns:
      features: dict 
    """
        if isinstance(audio, str):
            audio_features = load_audio_features(audio, max_secs=self.max_secs)
        elif isinstance(audio, tf.Tensor):
            audio_features = audio_features_from_wav(audio)
        else:
            audio_features = audio
        assert 'f0_hz' in audio_features.keys()

        n_samples = audio_features['audio'].shape[1]
        time_steps = int(self.time_steps_train * n_samples /
                         self.n_samples_train)
        z_time_steps = int(self.z_steps_train * n_samples /
                           self.n_samples_train)

        # -----------  Load Model for decoding ----------------
        gin_params = [
            'Harmonic.n_samples = {}'.format(n_samples),
            'FilteredNoise.n_samples = {}'.format(n_samples),
            'F0LoudnessPreprocessor.time_steps = {}'.format(time_steps),
            'F0NewLoudnessPreprocessor.time_steps = {}'.format(time_steps),
            'oscillator_bank.use_angular_cumsum = True',  # Avoids cumsum accumulation errors.
            # Encoders
            'MfccTimeConstantRnnEncoder.z_time_steps = {}'.format(z_time_steps),
            # TODO {ZEncoder,ZF0Encoder}.audio_net.time_steps for all audio_nets
            'MfccTimeAverageRnnEncoder.z_time_steps = {}'.format(z_time_steps),
            'SpectralNet.t_steps = {}'.format(z_time_steps),
            'DilatedConvNet.t_steps = {}'.format(z_time_steps),
        ]

        with gin.unlock_config():
            gin.parse_config(gin_params)

        # Set up the model just to predict audio given new conditioning
        start_time = time.time()
        model = get_model()
        model.restore(self.ckpt)

        # Build model by running a batch through it.
        for key in ['f0_hz', 'f0_confidence', 'loudness_db']:
            audio_features[key] = audio_features[key][:time_steps]
        audio_features['audio'] = audio_features['audio'][:, :n_samples]
        out = model(audio_features, training=False)

        print('Restoring model took %.1f seconds' % (time.time() - start_time))
        return model, out
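
A hypothetical call site for this method; the enclosing class, its constructor arguments (ckpt, max_secs, the *_train sizes), and the wrapper name below are assumptions inferred from the attributes used above:

# Illustrative only: `TrainedModelWrapper` stands in for the (unshown) class
# that owns load_and_forward; constructor arguments are guessed from self.* uses.
wrapper = TrainedModelWrapper(ckpt='/path/to/ckpt-40000', max_secs=4)
model, outputs = wrapper.load_and_forward('input_audio.wav')
# `outputs` holds the forward-pass features (conditioning plus synthesized audio).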
Example #6
File: tuned.py  Project: nielsrolf/ddsp
def get_trained_model(model_dir):
  gin_file = os.path.join(model_dir, "operative_config-0.gin")
  # Parse the gin config saved alongside the checkpoints.
  with gin.unlock_config():
    gin.parse_config_file(gin_file, skip_unknown=True)

  ckpt_files = [f for f in tf.io.gfile.listdir(model_dir) if "ckpt" in f]
  step_of = lambda f: int(f.split(".")[0].split("-")[1])
  latest = max([step_of(f) for f in ckpt_files])
  ckpt_name = [i for i in ckpt_files if step_of(i) == latest][0].split(".")[0]
  ckpt = os.path.join(model_dir, ckpt_name)

  model = get_model()
  model.restore(ckpt)
  return model, ckpt
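
A short usage sketch, assuming a standard DDSP run directory that contains operative_config-0.gin plus ckpt-* files (the path below is illustrative):

model, ckpt = get_trained_model(os.path.expanduser('~/ddsp_runs/solo_instrument'))
print('Restored model from checkpoint:', ckpt)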
Example #7
def prepare_complete_tfrecord(dataset_dir='nsynth_guitar',
                              split='train',
                              sample_rate=16000,
                              frame_rate=250):
    split_dir = os.path.join(dataset_dir, split)
    operative_config_file = train_util.get_latest_operative_config(split_dir)
    partial_tfrecord_file = os.path.join(split_dir, 'partial.tfrecord')
    complete_tfrecord_file = os.path.join(split_dir, 'complete.tfrecord')

    with gin.unlock_config():
        if tf.io.gfile.exists(operative_config_file):
            gin.parse_config_file(operative_config_file, skip_unknown=True)

    data_provider = PartialTFRecordProvider(
        file_pattern=partial_tfrecord_file + '*',
        example_secs=4,
        sample_rate=sample_rate,
        frame_rate=frame_rate)

    dataset = data_provider.get_batch(1, shuffle=False, repeats=1)

    strategy = train_util.get_strategy()
    with strategy.scope():
        model = models.get_model()
        trainer = trainers.Trainer(model, strategy)
        trainer.restore(split_dir)

        # steps = dataset.cardinality()

        with tf.io.TFRecordWriter(complete_tfrecord_file) as writer:
            for step, e in enumerate(dataset):
                start_time = time.perf_counter()

                sample_name = e['sample_name'][0].numpy()
                note_number = e['note_number'][0].numpy()
                velocity = e['velocity'][0].numpy()
                instrument_source = e['instrument_source'][0].numpy()
                qualities = e['qualities'][0].numpy()
                audio = e['audio'][0].numpy()
                f0_hz = e['f0_hz'][0].numpy()
                f0_confidence = e['f0_confidence'][0].numpy()
                loudness_db = e['loudness_db'][0].numpy()

                complete_dataset_dict = {
                    'sample_name': _byte_feature(sample_name),
                    'note_number': _int64_feature(note_number),
                    'velocity': _int64_feature(velocity),
                    'instrument_source': _int64_feature(instrument_source),
                    'qualities': _int64_feature(qualities),
                    'audio': _float_feature(audio),
                    'f0_hz': _float_feature(f0_hz),
                    'f0_confidence': _float_feature(f0_confidence),
                    'loudness_db': _float_feature(loudness_db),
                }

                e = model.encode(e, training=False)

                f0_scaled = tf.squeeze(e['f0_scaled']).numpy()
                ld_scaled = tf.squeeze(e['ld_scaled']).numpy()
                z = tf.reshape(e['z'], shape=[-1]).numpy()

                complete_dataset_dict.update({
                    'f0_scaled': _float_feature(f0_scaled),
                    'ld_scaled': _float_feature(ld_scaled),
                    'z': _float_feature(z),
                })

                tf_example = tf.train.Example(
                    features=tf.train.Features(feature=complete_dataset_dict))

                writer.write(tf_example.SerializeToString())

                stop_time = time.perf_counter()
                elapsed_time = stop_time - start_time
                print('{} - sample_name: {} - elapsed_time: {:.3f}'.format(
                    step+1, e['sample_name'], elapsed_time))
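
An illustrative invocation; the expected layout (a split subdirectory holding partial.tfrecord shards, an operative gin config, and a trained checkpoint) is inferred from the function body rather than documented here:

prepare_complete_tfrecord(dataset_dir='nsynth_guitar',
                          split='train',
                          sample_rate=16000,
                          frame_rate=250)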