예제 #1
0
def main(unused_argv):
    """Entry point: either write a generator bundle or generate sequences.

    The model config is looked up by the loaded bundle's generator id when a
    bundle is available, otherwise by ``FLAGS.config``. Hyperparameter
    overrides from ``FLAGS.hparams`` are applied to that config before the
    generator is built.

    Args:
      unused_argv: Leftover command-line arguments; not used.
    """
    tf.logging.set_verbosity(FLAGS.log)

    bundle = get_bundle()
    if bundle:
        config_id = bundle.generator_details.id
    else:
        config_id = FLAGS.config

    config = polyphony_model.default_configs[config_id]
    config.hparams.parse(FLAGS.hparams)

    # A batch bigger than beam_size * branch_factor only makes generation
    # slower without any benefit, so clamp it to that product.
    generation_cap = FLAGS.beam_size * FLAGS.branch_factor
    config.hparams.batch_size = min(config.hparams.batch_size, generation_cap)

    rnn_model = polyphony_model.PolyphonyRnnModel(config)
    generator = polyphony_sequence_generator.PolyphonyRnnSequenceGenerator(
        model=rnn_model,
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if not FLAGS.save_generator_bundle:
        run_with_flags(generator)
        return

    # Bundle-saving mode: write the generator out instead of generating.
    target_path = os.path.expanduser(FLAGS.bundle_file)
    if FLAGS.bundle_description is None:
        tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', target_path)
    generator.create_bundle_file(target_path, FLAGS.bundle_description)
def main(unused_argv):
    """Saves a generator bundle or runs the generator, depending on flags.

    The bundle (if any) is read once. It is used both to choose the model
    config and to construct the generator. When a bundle is present, the
    config is looked up by the bundle's generator id instead of
    ``FLAGS.config``, so the model built here matches the bundle it is loaded
    from. Without a bundle, ``FLAGS.config`` selects the config.

    Args:
      unused_argv: Command-line arguments left over after flag parsing;
        ignored.
    """
    tf.logging.set_verbosity(FLAGS.log)

    # Read the bundle once and reuse it below rather than calling
    # get_bundle() again inside the generator constructor.
    bundle = get_bundle()

    # The bundle carries the id of the config it was created with. Using it
    # keeps the model config consistent with the bundle's generator details.
    # Fall back to the --config flag when no bundle was loaded.
    config_id = bundle.generator_details.id if bundle else FLAGS.config
    config = polyphony_model.default_configs[config_id]

    generator = polyphony_sequence_generator.PolyphonyRnnSequenceGenerator(
        model=polyphony_model.PolyphonyRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if FLAGS.save_generator_bundle:
        bundle_filename = os.path.expanduser(FLAGS.bundle_file)
        if FLAGS.bundle_description is None:
            tf.logging.warning('No bundle description provided.')
        tf.logging.info('Saving generator bundle to %s', bundle_filename)
        generator.create_bundle_file(bundle_filename, FLAGS.bundle_description)
    else:
        run_with_flags(generator)
예제 #3
0
    def __init__(self, led):
        """Create the I/O helpers and a polyphony generator loaded from a bundle.

        All settings (output dir, bundle path, log level, hparams, beam size
        and branch factor) are read from ``MusicGeneratorSettings``.

        Args:
          led: Stored as ``self.led``.
        """
        self.button = Button()
        self.player = Player(MusicGeneratorSettings.output_dir)
        self.led = led

        bundle_path = os.path.expanduser(MusicGeneratorSettings.bundle_file)
        bundle = sequence_generator_bundle.read_bundle_file(bundle_path)
        tf.logging.set_verbosity(MusicGeneratorSettings.log)

        # The bundle's generator id selects which default config to use.
        config = polyphony_model.default_configs[bundle.generator_details.id]
        config.hparams.parse(MusicGeneratorSettings.hparams)

        # Clamp the batch size to beam_size * branch_factor; anything larger
        # would slow generation down without benefit.
        generation_cap = (MusicGeneratorSettings.beam_size *
                          MusicGeneratorSettings.branch_factor)
        config.hparams.batch_size = min(config.hparams.batch_size,
                                        generation_cap)

        rnn_model = polyphony_model.PolyphonyRnnModel(config)
        self.generator = (
            polyphony_sequence_generator.PolyphonyRnnSequenceGenerator(
                model=rnn_model,
                details=config.details,
                steps_per_quarter=config.steps_per_quarter,
                # No checkpoint directory; the generator gets the bundle.
                checkpoint=None,
                bundle=bundle))
예제 #4
0
 def create_sequence_generator(config, **kwargs):
     """Build a polyphony RNN sequence generator for ``config``.

     Args:
       config: Model config. It is used to construct the
         ``PolyphonyRnnModel``. Its ``details`` and ``steps_per_quarter``
         attributes are forwarded to the generator.
       **kwargs: Extra keyword arguments passed through unchanged to
         ``PolyphonyRnnSequenceGenerator`` (e.g. ``checkpoint``, ``bundle``).

     Returns:
       A new ``PolyphonyRnnSequenceGenerator``.
     """
     rnn_model = polyphony_model.PolyphonyRnnModel(config)
     generator_details = config.details
     return PolyphonyRnnSequenceGenerator(
         rnn_model,
         generator_details,
         steps_per_quarter=config.steps_per_quarter,
         **kwargs)
예제 #5
0
 def create_sequence_generator(config, **kwargs):
     """Build a polyphony RNN sequence generator for ``config``.

     Args:
       config: Model config. It is used to construct the
         ``PolyphonyRnnModel``, and its ``details`` attribute is passed as the
         generator's second positional argument.
       **kwargs: Extra keyword arguments forwarded unchanged to
         ``PolyphonyRnnSequenceGenerator``.

     Returns:
       A new ``PolyphonyRnnSequenceGenerator``.
     """
     rnn_model = polyphony_model.PolyphonyRnnModel(config)
     generator_details = config.details
     return PolyphonyRnnSequenceGenerator(rnn_model, generator_details,
                                          **kwargs)