Example #1
0
 def parse_gin_config(self, ckpt):
     """Load the checkpoint's operative gin config and patch it for streaming.

     The operative config is looked up in the directory containing `ckpt`.
     After parsing it, the per-frame sample count is derived from the parsed
     `Harmonic.n_samples` and `F0PowerPreprocessor.time_steps` values, and the
     config is overridden so that a single time step is processed per call.

     Args:
       ckpt: Path to a checkpoint; only its parent directory is used here.
     """
     # Replacement processor DAG: harmonic synth + filtered noise, summed.
     # It contains no reverb processor.
     dag_override = """ProcessorGroup.dag = [
   (@synths.Harmonic(),
     ['amps', 'harmonic_distribution', 'f0_hz']),
   (@synths.FilteredNoise(),
     ['noise_magnitudes']),
   (@processors.Add(),
     ['filtered_noise/signal', 'harmonic/signal']),
   ]"""
     checkpoint_folder = os.path.dirname(ckpt)

     with gin.unlock_config():
         config_path = train_util.get_latest_operative_config(
             checkpoint_folder)
         print('Parsing from operative_config {}'.format(config_path))
         gin.parse_config_file(config_path, skip_unknown=True)

         # Read the lengths the model was configured with, then shrink
         # everything down to one frame's worth of samples.
         total_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
         total_samples = gin.query_parameter('Harmonic.n_samples')
         frame_size = int(total_samples / total_steps)

         streaming_overrides = [
             'F0PowerPreprocessor.time_steps=1',
             'Harmonic.n_samples={}'.format(frame_size),
             'FilteredNoise.n_samples={}'.format(frame_size),
             dag_override,
         ]
         gin.parse_config(streaming_overrides)
Example #2
0
def parse_gin(restore_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory.

    Configs are parsed in this order, with later ones overriding earlier ones:
    the optimization default, the eval default, the operative config found
    for `restore_dir` (only if that file exists), and finally the user's gin
    files and parameter bindings taken from flags.

    Args:
      restore_dir: Directory passed to
        `train_util.get_latest_operative_config` to locate a saved config.
    """
    # Make the built-in gin folder and any user folders searchable.
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for path in search_paths:
        gin.add_config_file_search_path(path)

    with gin.unlock_config():
        # Defaults: optimization settings depend on whether a TPU was given.
        optimization_gin = 'base_tpu.gin' if FLAGS.tpu else 'base.gin'
        gin.parse_config_file(os.path.join('optimization', optimization_gin))
        gin.parse_config_file('eval/basic.gin')

        # A previously saved operative config, when present, comes next.
        saved_config = train_util.get_latest_operative_config(restore_dir)
        if tf.io.gfile.exists(saved_config):
            logging.info('Using operative config: %s', saved_config)
            local_config = cloud.make_file_paths_local(saved_config, GIN_PATH)
            gin.parse_config_file(local_config, skip_unknown=True)

        # User-provided files and bindings are applied last.
        user_gin_files = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
        gin.parse_config_files_and_bindings(
            user_gin_files, FLAGS.gin_param, skip_unknown=True)
Example #3
0
 def parse_gin_config(self, ckpt):
     """Load the checkpoint's operative gin config and apply new lengths.

     The operative config is looked up in the directory containing `ckpt`.
     After parsing it, the sample counts and time steps are overridden with
     `self.n_samples` and `self.time_steps`, angular cumsum is enabled for the
     oscillator bank, and the processor DAG is replaced.

     Args:
       ckpt: Path to a checkpoint; only its parent directory is used here.
     """
     # Replacement processor DAG: harmonic synth + filtered noise, summed.
     # It contains no reverb processor.
     dag_override = """ProcessorGroup.dag = [
   (@synths.Harmonic(),
     ['amps', 'harmonic_distribution', 'f0_hz']),
   (@synths.FilteredNoise(),
     ['noise_magnitudes']),
   (@processors.Add(),
     ['filtered_noise/signal', 'harmonic/signal']),
   ]"""
     checkpoint_folder = os.path.dirname(ckpt)

     with gin.unlock_config():
         config_path = train_util.get_latest_operative_config(
             checkpoint_folder)
         print('Parsing from operative_config {}'.format(config_path))
         gin.parse_config_file(config_path, skip_unknown=True)

         # Bindings that resize the model to this instance's lengths.
         length_overrides = [
             'Harmonic.n_samples=%d' % self.n_samples,
             'FilteredNoise.n_samples=%d' % self.n_samples,
             'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps,
             'oscillator_bank.use_angular_cumsum=True',
             dag_override,
         ]
         gin.parse_config(length_overrides)
Example #4
0
 def parse_gin_config(self, ckpt):
     """Parse the operative gin config found for the checkpoint's directory.

     Args:
       ckpt: Path to a checkpoint; only its parent directory is used here.
     """
     checkpoint_folder = os.path.dirname(ckpt)
     with gin.unlock_config():
         config_path = train_util.get_latest_operative_config(
             checkpoint_folder)
         print('Parsing from operative_config {}'.format(config_path))
         # Unknown configurables in the saved file are skipped, not errors.
         gin.parse_config_file(config_path, skip_unknown=True)
Example #5
0
def parse_operative_config(ckpt_dir):
  """Parse the operative gin config found for `ckpt_dir`.

  Args:
    ckpt_dir: Directory passed to `train_util.get_latest_operative_config`.
  """
  with gin.unlock_config():
    config_path = train_util.get_latest_operative_config(ckpt_dir)
    print('Parsing from operative_config {}'.format(config_path))
    # Unknown configurables in the saved file are skipped, not errors.
    gin.parse_config_file(config_path, skip_unknown=True)
Example #6
0
def prepare_complete_tfrecord(dataset_dir='nsynth_guitar',
                              split='train',
                              sample_rate=16000,
                              frame_rate=250):
    """Write `complete.tfrecord` by adding encoder outputs to partial records.

    Reads examples matching `<dataset_dir>/<split>/partial.tfrecord*`, restores
    a model from the split directory, runs `model.encode` on every example and
    writes one `tf.train.Example` per input to
    `<dataset_dir>/<split>/complete.tfrecord`. Each written example holds the
    original fields plus `f0_scaled`, `ld_scaled` and `z` from the encoder.

    If an operative gin config exists for the split directory it is parsed
    first; otherwise the current gin configuration is used unchanged.

    Args:
      dataset_dir: Root folder of the dataset.
      split: Sub-folder of `dataset_dir` holding the records and checkpoint.
      sample_rate: Forwarded to `PartialTFRecordProvider`.
      frame_rate: Forwarded to `PartialTFRecordProvider`.
    """
    split_dir = os.path.join(dataset_dir, split)
    operative_config_file = train_util.get_latest_operative_config(split_dir)
    partial_tfrecord_file = os.path.join(split_dir, 'partial.tfrecord')
    complete_tfrecord_file = os.path.join(split_dir, 'complete.tfrecord')

    with gin.unlock_config():
        if tf.io.gfile.exists(operative_config_file):
            gin.parse_config_file(operative_config_file, skip_unknown=True)

    data_provider = PartialTFRecordProvider(
        file_pattern=partial_tfrecord_file + '*',
        example_secs=4,
        sample_rate=sample_rate,
        frame_rate=frame_rate)

    # Batch size 1, single pass, fixed order: one output record per input.
    dataset = data_provider.get_batch(1, shuffle=False, repeats=1)

    strategy = train_util.get_strategy()
    with strategy.scope():
        model = models.get_model()
        trainer = trainers.Trainer(model, strategy)
        trainer.restore(split_dir)

        with tf.io.TFRecordWriter(complete_tfrecord_file) as writer:
            for step, e in enumerate(dataset):
                start_time = time.perf_counter()

                # Strip the batch dimension (batch size is 1).
                sample_name = e['sample_name'][0].numpy()
                note_number = e['note_number'][0].numpy()
                velocity = e['velocity'][0].numpy()
                instrument_source = e['instrument_source'][0].numpy()
                qualities = e['qualities'][0].numpy()
                audio = e['audio'][0].numpy()
                f0_hz = e['f0_hz'][0].numpy()
                f0_confidence = e['f0_confidence'][0].numpy()
                loudness_db = e['loudness_db'][0].numpy()

                complete_dataset_dict = {
                    'sample_name': _byte_feature(sample_name),
                    'note_number': _int64_feature(note_number),
                    'velocity': _int64_feature(velocity),
                    'instrument_source': _int64_feature(instrument_source),
                    'qualities': _int64_feature(qualities),
                    'audio': _float_feature(audio),
                    'f0_hz': _float_feature(f0_hz),
                    'f0_confidence': _float_feature(f0_confidence),
                    'loudness_db': _float_feature(loudness_db),
                }

                # Keep the encoder output under its own name so the input
                # example `e` is not shadowed for the rest of the iteration.
                encoded = model.encode(e, training=False)

                f0_scaled = tf.squeeze(encoded['f0_scaled']).numpy()
                ld_scaled = tf.squeeze(encoded['ld_scaled']).numpy()
                # Flatten to 1-D; `[-1]` is the documented shape form.
                z = tf.reshape(encoded['z'], shape=[-1]).numpy()

                complete_dataset_dict.update({
                    'f0_scaled': _float_feature(f0_scaled),
                    'ld_scaled': _float_feature(ld_scaled),
                    'z': _float_feature(z),
                })

                tf_example = tf.train.Example(
                    features=tf.train.Features(feature=complete_dataset_dict))

                writer.write(tf_example.SerializeToString())

                stop_time = time.perf_counter()
                elapsed_time = stop_time - start_time
                # Log the name extracted from the input example, so the
                # message does not depend on what the encoder returns.
                print('{} - sample_name: {} - elapsed_time: {:.3f}'.format(
                    step+1, sample_name, elapsed_time))