def main(unused_argv):
    """Entry point: save a generator bundle or generate drum tracks.

    If a bundle is available, the model configuration is looked up by the
    bundle's generator id and the --hparams overrides are applied to it;
    otherwise the configuration is built from the config flags.
    """
    tf.logging.set_verbosity(FLAGS.log)

    bundle = get_bundle()
    if not bundle:
        config = drums_rnn_config_flags.config_from_flags()
    else:
        config = drums_rnn_model.default_configs[
            bundle.generator_details.id]
        config.hparams.parse(FLAGS.hparams)

    # A batch larger than beam_size * branch_factor only slows generation
    # down, so cap it there.
    largest_useful_batch = FLAGS.beam_size * FLAGS.branch_factor
    config.hparams.batch_size = min(config.hparams.batch_size,
                                    largest_useful_batch)

    drums_model = drums_rnn_model.DrumsRnnModel(config)
    generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
        model=drums_model,
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if not FLAGS.save_generator_bundle:
        run_with_flags(generator)
        return

    bundle_filename = os.path.expanduser(FLAGS.bundle_file)
    if FLAGS.bundle_description is None:
        tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', bundle_filename)
    generator.create_bundle_file(bundle_filename, FLAGS.bundle_description)
def get_generator_map():
  """Returns a map from the generator ID to a SequenceGenerator factory.

  Binds the `config` argument so that the resulting callable matches the
  BaseSequenceGenerator constructor arguments.

  Each call of a factory builds its own DrumsRnnModel for the config. It
  passes the config's `steps_per_quarter` unless the caller supplies one.

  Returns:
    Map from the generator ID to a callable that creates a
    DrumsRnnSequenceGenerator with a bound `config` argument.
  """
  def _create_generator(config, *args, **kwargs):
    # Build the model lazily. Requesting the map then no longer constructs a
    # model for every config, and generators do not share a model instance.
    if not args:
      # Respect an explicitly passed steps_per_quarter. Otherwise use the
      # config's value rather than relying on the constructor's default.
      kwargs.setdefault('steps_per_quarter', config.steps_per_quarter)
    return DrumsRnnSequenceGenerator(
        drums_rnn_model.DrumsRnnModel(config), config.details,
        *args, **kwargs)

  return {key: partial(_create_generator, config)
          for key, config in drums_rnn_model.default_configs.items()}
def main(unused_argv):
    """Saves bundle or runs generator based on flags.

    When a bundle is supplied, the model configuration is looked up by the
    bundle's generator id, with any --hparams overrides applied. The model
    is therefore built from the configuration the bundle names. Without a
    bundle, the configuration comes from the config flags.
    """
    tf.logging.set_verbosity(FLAGS.log)

    # Read the bundle once; it drives both the config lookup and the
    # generator.
    bundle = get_bundle()
    if bundle:
        config_id = bundle.generator_details.id
        config = drums_rnn_model.default_configs[config_id]
        config.hparams.parse(FLAGS.hparams)
    else:
        config = drums_rnn_config_flags.config_from_flags()
    # Having too large of a batch size will slow generation down unnecessarily.
    config.hparams.batch_size = min(config.hparams.batch_size,
                                    FLAGS.beam_size * FLAGS.branch_factor)

    generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
        model=drums_rnn_model.DrumsRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if FLAGS.save_generator_bundle:
        bundle_filename = os.path.expanduser(FLAGS.bundle_file)
        if FLAGS.bundle_description is None:
            tf.logging.warning('No bundle description provided.')
        tf.logging.info('Saving generator bundle to %s', bundle_filename)
        generator.create_bundle_file(bundle_filename, FLAGS.bundle_description)
    else:
        run_with_flags(generator)
Example #4
0
 def create_sequence_generator(config, **kwargs):
     """Build a DrumsRnnSequenceGenerator for the given config.

     The model and details come from `config`, as does the step resolution
     (`config.steps_per_quarter`). Any extra keyword arguments (e.g. a
     checkpoint or bundle) are forwarded to the generator unchanged.
     """
     drums_model = drums_rnn_model.DrumsRnnModel(config)
     return DrumsRnnSequenceGenerator(
         drums_model, config.details,
         steps_per_quarter=config.steps_per_quarter, **kwargs)
 def create_sequence_generator(config, **kwargs):
     """Build a DrumsRnnSequenceGenerator for the given config.

     Uses `config.steps_per_quarter` as the step resolution unless the caller
     passes `steps_per_quarter` explicitly. Other keyword arguments are
     forwarded to the generator unchanged.
     """
     # Previously steps_per_quarter was never passed, so the constructor's
     # default was used regardless of the config. setdefault keeps an
     # explicit caller value working instead of raising a duplicate-keyword
     # error.
     kwargs.setdefault('steps_per_quarter', config.steps_per_quarter)
     return DrumsRnnSequenceGenerator(drums_rnn_model.DrumsRnnModel(config),
                                      config.details, **kwargs)
Example #6
0
def generate_drums():
    """Generate a new drum groove by querying the model.

    Reads the module-level drums_bundle, seed_drum_sequence, num_steps, qpm
    and temperature, and then runs the drums RNN. Afterwards it:

    - replaces every entry of type 'drums' in the global ``playable_notes``
      with the newly generated pattern;
    - sets the global ``total_seconds`` to the requested generation length.
    """
    import copy

    # Only these two globals are rebound here; the others are just read.
    global playable_notes
    global total_seconds

    drums_config_id = drums_bundle.generator_details.id
    drums_config = drums_rnn_model.default_configs[drums_config_id]
    generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
        model=drums_rnn_model.DrumsRnnModel(drums_config),
        details=drums_config.details,
        steps_per_quarter=drums_config.steps_per_quarter,
        # NOTE(review): this uses melody_rnn_generate's checkpoint helper for
        # a drums model -- confirm this is intended.
        checkpoint=melody_rnn_generate.get_checkpoint(),
        bundle=drums_bundle)

    generator_options = generator_pb2.GeneratorOptions()
    generator_options.args['temperature'].float_value = temperature
    generator_options.args['beam_size'].int_value = 1
    generator_options.args['branch_factor'].int_value = 1
    generator_options.args['steps_per_iteration'].int_value = 1

    if seed_drum_sequence is None:
        # No seed: prime with a single drum hit at pitch 36.
        primer_drums = magenta.music.DrumTrack([frozenset([36])])
        primer_sequence = primer_drums.to_sequence(qpm=qpm)
        local_num_steps = num_steps
    else:
        # Work on a copy. Adding the tempo to the global seed itself would
        # append one more tempo entry to the shared sequence on every call.
        primer_sequence = copy.deepcopy(seed_drum_sequence)
        # Generate past the seed; the seed's steps are removed further down
        # (presumably the seed spans num_steps steps -- confirm).
        local_num_steps = num_steps * 2
        tempo = primer_sequence.tempos.add()
        tempo.qpm = qpm

    # Length of one step in seconds at 4 steps per quarter note.
    step_length = 60. / qpm / 4.0
    total_seconds = local_num_steps * step_length
    # Set the start time to begin on the next step after the last note ends.
    last_end_time = (max(n.end_time for n in primer_sequence.notes)
                     if primer_sequence.notes else 0)
    generator_options.generate_sections.add(
        start_time=last_end_time + step_length,
        end_time=total_seconds)

    generated_sequence = generator.generate(primer_sequence, generator_options)
    generated_sequence = sequences_lib.quantize_note_sequence(
        generated_sequence, 4)

    if seed_drum_sequence is not None:
        # Drop notes starting inside the seed region and shift the remaining
        # notes back by num_steps. An index loop is used because notes are
        # deleted from the repeated field in place.
        i = 0
        while i < len(generated_sequence.notes):
            note = generated_sequence.notes[i]
            if note.quantized_start_step < num_steps:
                del generated_sequence.notes[i]
            else:
                note.quantized_start_step -= num_steps
                note.quantized_end_step -= num_steps
                i += 1

    drum_pattern = [(n.pitch, n.quantized_start_step, n.quantized_end_step)
                    for n in generated_sequence.notes]

    # First clear the last drum pattern.
    if playable_notes:
        playable_notes = SortedList(
            [x for x in playable_notes if x.type != 'drums'],
            key=lambda x: x.onset)
    for pitch, start_step, _ in drum_pattern:
        playable_notes.add(
            PlayableNote(type='drums',
                         note=[],
                         instrument=DRUM_MAPPING[pitch],
                         onset=start_step))
Example #7
0

# Load the pre-trained drum kit bundle. The path is relative, so it is
# resolved against the current working directory (expanduser only changes
# paths that start with "~").
bundle_file = os.path.expanduser("drum_kit_rnn.mag")
bundle = magenta.music.read_bundle_file(bundle_file)

# Pick the model configuration named by the bundle's generator id.
config_id = bundle.generator_details.id
config = drums_rnn_model.default_configs[config_id]

# Search settings used here to cap the model's batch size.
beam_size = 1
branch_factor = 1

# A batch larger than beam_size * branch_factor would only slow generation.
# NOTE(review): generate() below defaults to beam_size=2, which exceeds the
# cap computed here -- confirm the intended batch size.
config.hparams.batch_size = min(config.hparams.batch_size,
                                beam_size * branch_factor)

# Generator built from the bundle; no separate checkpoint is passed.
generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
    model=drums_rnn_model.DrumsRnnModel(config),
    details=config.details,
    steps_per_quarter=config.steps_per_quarter,
    bundle=bundle)

# Default primer for generate(): appears to be one tuple of drum pitches per
# step, with an empty tuple meaning no hit -- confirm against the parser.
example = "[(36,45), (), (36,), (), (36,), (36,), (), (36,), (36,46,), (45,), (36,46,), ()]"


def generate(primer=example,
             qpm=120,
             num_steps=120,
             temperature=1,
             branch_factor=1,
             beam_size=2,
             steps_per_iteration=1):