def main(unused_argv):
  """Entry point: either writes a generator bundle or generates drum tracks.

  The model configuration is taken from the bundle's default config (with
  FLAGS.hparams overrides applied) when a bundle is available, and from the
  config flags otherwise. The batch size is capped at
  FLAGS.beam_size * FLAGS.branch_factor. If FLAGS.save_generator_bundle is
  set, the generator is serialized to FLAGS.bundle_file; otherwise
  generation is delegated to `run_with_flags`.

  Args:
    unused_argv: Positional command-line arguments (ignored).
  """
  tf.logging.set_verbosity(FLAGS.log)

  bundle = get_bundle()
  if not bundle:
    config = drums_rnn_config_flags.config_from_flags()
  else:
    # Look up the built-in config named by the bundle, then apply any
    # hparams overrides given on the command line.
    config = drums_rnn_model.default_configs[bundle.generator_details.id]
    config.hparams.parse(FLAGS.hparams)

  # Having too large of a batch size will slow generation down unnecessarily.
  useful_batch_limit = FLAGS.beam_size * FLAGS.branch_factor
  config.hparams.batch_size = min(config.hparams.batch_size,
                                  useful_batch_limit)

  model = drums_rnn_model.DrumsRnnModel(config)
  generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
      model=model,
      details=config.details,
      steps_per_quarter=config.steps_per_quarter,
      checkpoint=get_checkpoint(),
      bundle=bundle)

  if not FLAGS.save_generator_bundle:
    run_with_flags(generator)
    return

  bundle_filename = os.path.expanduser(FLAGS.bundle_file)
  if FLAGS.bundle_description is None:
    tf.logging.warning('No bundle description provided.')
  tf.logging.info('Saving generator bundle to %s', bundle_filename)
  generator.create_bundle_file(bundle_filename, FLAGS.bundle_description)
def get_generator_map():
  """Builds a mapping from generator ID to a SequenceGenerator factory.

  Each value is a `partial` of DrumsRnnSequenceGenerator whose first two
  positional arguments (a DrumsRnnModel built from the config, and the
  config's `details`) are already filled in. This lets the remaining
  constructor call match the BaseSequenceGenerator signature.

  Returns:
    Dict mapping each key of `drums_rnn_model.default_configs` to a
    `partial` wrapping DrumsRnnSequenceGenerator.
  """
  generator_map = {}
  for generator_id, drums_config in drums_rnn_model.default_configs.items():
    bound_model = drums_rnn_model.DrumsRnnModel(drums_config)
    generator_map[generator_id] = partial(
        DrumsRnnSequenceGenerator, bound_model, drums_config.details)
  return generator_map
def main(unused_argv):
  """Entry point: serializes a generator bundle or runs generation.

  The model configuration always comes from the config flags. If
  FLAGS.save_generator_bundle is set, the generator is written to
  FLAGS.bundle_file; otherwise generation is delegated to `run_with_flags`.

  Args:
    unused_argv: Positional command-line arguments (ignored).
  """
  tf.logging.set_verbosity(FLAGS.log)

  config = drums_rnn_config_flags.config_from_flags()
  model = drums_rnn_model.DrumsRnnModel(config)
  checkpoint = get_checkpoint()
  bundle = get_bundle()
  generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
      model=model,
      details=config.details,
      steps_per_quarter=config.steps_per_quarter,
      checkpoint=checkpoint,
      bundle=bundle)

  if not FLAGS.save_generator_bundle:
    run_with_flags(generator)
    return

  bundle_filename = os.path.expanduser(FLAGS.bundle_file)
  if FLAGS.bundle_description is None:
    tf.logging.warning('No bundle description provided.')
  tf.logging.info('Saving generator bundle to %s', bundle_filename)
  generator.create_bundle_file(bundle_filename, FLAGS.bundle_description)
def create_sequence_generator(config, **kwargs):
  """Instantiates a DrumsRnnSequenceGenerator for the given config.

  Args:
    config: Drums RNN config. It is used to build the DrumsRnnModel and
        supplies `details` and `steps_per_quarter`.
    **kwargs: Additional keyword arguments forwarded unchanged to the
        DrumsRnnSequenceGenerator constructor.

  Returns:
    A new DrumsRnnSequenceGenerator.
  """
  model = drums_rnn_model.DrumsRnnModel(config)
  details = config.details
  return DrumsRnnSequenceGenerator(
      model,
      details,
      steps_per_quarter=config.steps_per_quarter,
      **kwargs)
def create_sequence_generator(config, **kwargs):
  """Instantiates a DrumsRnnSequenceGenerator for the given config.

  Args:
    config: Drums RNN config. It is used to build the DrumsRnnModel and
        supplies `details`.
    **kwargs: Additional keyword arguments forwarded unchanged to the
        DrumsRnnSequenceGenerator constructor.

  Returns:
    A new DrumsRnnSequenceGenerator.
  """
  model = drums_rnn_model.DrumsRnnModel(config)
  details = config.details
  return DrumsRnnSequenceGenerator(model, details, **kwargs)
def generate_drums():
  """Generate a new drum groove by querying the model.

  Reads the module-level state `drums_bundle`, `seed_drum_sequence`,
  `num_steps`, `qpm` and `temperature`. It then runs the drums RNN
  generator and replaces every 'drums' entry in the global `playable_notes`
  with the newly generated pattern. It also sets the global `total_seconds`.

  When `seed_drum_sequence` is None, the model is primed with a single step
  containing pitch 36. Otherwise the seed is used as the primer and twice
  `num_steps` steps are requested; notes that start before step `num_steps`
  are then removed and the rest shifted back by `num_steps`. This
  presumably assumes the seed spans `num_steps` steps (TODO confirm).
  The global `seed_drum_sequence` itself is never modified.
  """
  global drums_bundle
  global generated_drums
  global playable_notes
  global seed_drum_sequence
  global num_steps
  global qpm
  global total_seconds
  global temperature

  import copy  # Local import: only needed to copy the seed sequence below.

  drums_config_id = drums_bundle.generator_details.id
  drums_config = drums_rnn_model.default_configs[drums_config_id]
  generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator(
      model=drums_rnn_model.DrumsRnnModel(drums_config),
      details=drums_config.details,
      steps_per_quarter=drums_config.steps_per_quarter,
      checkpoint=melody_rnn_generate.get_checkpoint(),
      bundle=drums_bundle)

  generator_options = generator_pb2.GeneratorOptions()
  generator_options.args['temperature'].float_value = temperature
  generator_options.args['beam_size'].int_value = 1
  generator_options.args['branch_factor'].int_value = 1
  generator_options.args['steps_per_iteration'].int_value = 1

  if seed_drum_sequence is None:
    primer_drums = magenta.music.DrumTrack([frozenset([36])])
    primer_sequence = primer_drums.to_sequence(qpm=qpm)
    local_num_steps = num_steps
  else:
    # Work on a copy: the tempo appended below would otherwise be added to
    # the global seed itself, so its tempo list would grow on every call.
    primer_sequence = copy.deepcopy(seed_drum_sequence)
    local_num_steps = num_steps * 2

  tempo = primer_sequence.tempos.add()
  tempo.qpm = qpm

  # Length of one step in seconds, with 4 steps per quarter note (matching
  # the quantization below).
  step_length = 60. / qpm / 4.0
  total_seconds = local_num_steps * step_length

  # Set the start time to begin on the next step after the last note ends.
  last_end_time = (max(n.end_time for n in primer_sequence.notes)
                   if primer_sequence.notes else 0)
  generator_options.generate_sections.add(
      start_time=last_end_time + step_length,
      end_time=total_seconds)

  generated_sequence = generator.generate(primer_sequence, generator_options)
  generated_sequence = sequences_lib.quantize_note_sequence(
      generated_sequence, 4)

  if seed_drum_sequence is not None:
    # Drop notes starting in the seed region and shift the continuation so
    # it starts at step 0. The index only advances when nothing is deleted.
    i = 0
    while i < len(generated_sequence.notes):
      note = generated_sequence.notes[i]
      if note.quantized_start_step < num_steps:
        del generated_sequence.notes[i]
      else:
        note.quantized_start_step -= num_steps
        note.quantized_end_step -= num_steps
        i += 1

  drum_pattern = [(n.pitch, n.quantized_start_step, n.quantized_end_step)
                  for n in generated_sequence.notes]

  # First clear the last drum pattern.
  if len(playable_notes) > 0:
    playable_notes = SortedList(
        [x for x in playable_notes if x.type != 'drums'],
        key=lambda x: x.onset)
  for pitch, start_step, _ in drum_pattern:
    playable_notes.add(
        PlayableNote(type='drums',
                     note=[],
                     instrument=DRUM_MAPPING[pitch],
                     onset=start_step))
bundle_file = os.path.expanduser("drum_kit_rnn.mag") bundle = magenta.music.read_bundle_file(bundle_file) config_id = bundle.generator_details.id config = drums_rnn_model.default_configs[config_id] beam_size = 1 branch_factor = 1 config.hparams.batch_size = min(config.hparams.batch_size, beam_size * branch_factor) generator = drums_rnn_sequence_generator.DrumsRnnSequenceGenerator( model=drums_rnn_model.DrumsRnnModel(config), details=config.details, steps_per_quarter=config.steps_per_quarter, bundle=bundle) example = "[(36,45), (), (36,), (), (36,), (36,), (), (36,), (36,46,), (45,), (36,46,), ()]" def generate(primer=example, qpm=120, num_steps=120, temperature=1, branch_factor=1, beam_size=2, steps_per_iteration=1):