def _generate(self, input_sequence, generator_options):
    """Extend a priming sequence with a drum track produced by the model.

    Args:
        input_sequence: NoteSequence used as the primer. Its first tempo,
            when present, supplies the qpm; otherwise
            note_seq.DEFAULT_QUARTERS_PER_MINUTE is used.
        generator_options: GeneratorOptions holding at most one
            input_sections entry, exactly one generate_sections entry, and
            optional 'temperature', 'beam_size', 'branch_factor' and
            'steps_per_iteration' args.

    Returns:
        The sequence obtained by calling to_sequence(qpm=qpm) on the drum
        track returned by the model.

    Raises:
        sequence_generator.SequenceGeneratorError: If more than one input
            section or a number of generate sections other than one is
            supplied, or if the primer's last note ends after the requested
            generation start time.
    """
    num_input_sections = len(generator_options.input_sections)
    if num_input_sections > 1:
        raise sequence_generator.SequenceGeneratorError(
            'This model supports at most one input_sections message, but got %s'
            % num_input_sections)

    num_generate_sections = len(generator_options.generate_sections)
    if num_generate_sections != 1:
        raise sequence_generator.SequenceGeneratorError(
            'This model supports only 1 generate_sections message, but got %s'
            % num_generate_sections)

    has_tempo = input_sequence and input_sequence.tempos
    qpm = (input_sequence.tempos[0].qpm
           if has_tempo else note_seq.DEFAULT_QUARTERS_PER_MINUTE)
    steps_per_second = note_seq.steps_per_quarter_to_steps_per_second(
        self.steps_per_quarter, qpm)

    target_section = generator_options.generate_sections[0]

    # Restrict the primer to the requested input section, if one was given.
    if generator_options.input_sections:
        source_section = generator_options.input_sections[0]
        primer = note_seq.trim_note_sequence(
            input_sequence, source_section.start_time,
            source_section.end_time)
        primer_start_step = note_seq.quantize_to_step(
            source_section.start_time, steps_per_second, quantize_cutoff=0.0)
    else:
        primer, primer_start_step = input_sequence, 0

    # Generation may only begin at or after the end of the last primer note.
    final_note_end = max((note.end_time for note in primer.notes), default=0)
    if final_note_end > target_section.start_time:
        raise sequence_generator.SequenceGeneratorError(
            'Got GenerateSection request for section that is before the end of '
            'the NoteSequence. This model can only extend sequences. Requested '
            'start time: %s, Final note end time: %s' %
            (target_section.start_time, final_note_end))

    # Quantize the priming sequence.
    quantized_primer = note_seq.quantize_note_sequence(
        primer, self.steps_per_quarter)

    # An infinite gap_bars ensures that the entire input is used.
    extracted_drum_tracks, _ = drum_pipelines.extract_drum_tracks(
        quantized_primer,
        search_start_step=primer_start_step,
        min_bars=0,
        gap_bars=float('inf'),
        ignore_is_drum=True)
    assert len(extracted_drum_tracks) <= 1

    start_step = note_seq.quantize_to_step(
        target_section.start_time, steps_per_second, quantize_cutoff=0.0)
    # quantize_cutoff=1.0 makes end_step always round down, so a requested
    # end time of 4.99 seconds never yields a sequence ending at 5.0 seconds.
    end_step = note_seq.quantize_to_step(
        target_section.end_time, steps_per_second, quantize_cutoff=1.0)

    drums = extracted_drum_tracks[0] if extracted_drum_tracks else None
    if not drums:
        # Nothing usable was extracted: build an empty drum track starting one
        # step before start_step, which yields one step of silence once the
        # track is extended below.
        steps_per_bar = int(
            note_seq.steps_per_bar_in_quantized_sequence(quantized_primer))
        drums = note_seq.DrumTrack(
            [],
            start_step=max(0, start_step - 1),
            steps_per_bar=steps_per_bar,
            steps_per_quarter=self.steps_per_quarter)

    # Make the drum track reach the step where generation should begin.
    drums.set_length(start_step - drums.start_step)

    # Pull the supported generation arguments out of the generator options;
    # each entry names the field that carries the argument's value.
    value_field_by_arg = {
        'temperature': 'float_value',
        'beam_size': 'int_value',
        'branch_factor': 'int_value',
        'steps_per_iteration': 'int_value',
    }
    model_kwargs = {
        arg_name: getattr(generator_options.args[arg_name], field_name)
        for arg_name, field_name in value_field_by_arg.items()
        if arg_name in generator_options.args
    }

    generated_drums = self._model.generate_drum_track(
        end_step - drums.start_step, drums, **model_kwargs)
    generated_sequence = generated_drums.to_sequence(qpm=qpm)
    assert (generated_sequence.total_time - target_section.end_time) <= 1e-5
    return generated_sequence
def run_with_flags(generator):
    """Generate drum tracks and write each one out as a MIDI file.

    All settings are read from the flags defined in this module.

    Args:
        generator: The DrumsRnnSequenceGenerator to use for generation.
    """
    if not FLAGS.output_dir:
        tf.logging.fatal('--output_dir required')
        return
    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)

    primer_midi = (
        os.path.expanduser(FLAGS.primer_midi) if FLAGS.primer_midi else None)

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    # Choose the primer: explicit drum events, then a MIDI file, then a
    # default single bass drum hit.
    primer_sequence = None
    qpm = FLAGS.qpm or note_seq.DEFAULT_QUARTERS_PER_MINUTE
    if FLAGS.primer_drums:
        drum_events = [
            frozenset(pitches)
            for pitches in ast.literal_eval(FLAGS.primer_drums)
        ]
        primer_sequence = note_seq.DrumTrack(drum_events).to_sequence(qpm=qpm)
    elif primer_midi:
        primer_sequence = note_seq.midi_file_to_sequence_proto(primer_midi)
        # A non-zero first tempo in the MIDI primer overrides the qpm.
        if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
            qpm = primer_sequence.tempos[0].qpm
    else:
        tf.logging.warning(
            'No priming sequence specified. Defaulting to a single bass drum hit.'
        )
        primer_sequence = note_seq.DrumTrack(
            [frozenset([36])]).to_sequence(qpm=qpm)

    # Total duration in seconds follows from the qpm and the num_steps flag.
    seconds_per_step = 60.0 / qpm / generator.steps_per_quarter
    total_seconds = FLAGS.num_steps * seconds_per_step

    # Generation starts at the end of the priming sequence and continues
    # until the whole sequence is num_steps long.
    generator_options = generator_pb2.GeneratorOptions()
    if primer_sequence:
        input_sequence = primer_sequence
        # Begin on the next step after the last primer note ends.
        final_note_end = max(
            (note.end_time for note in primer_sequence.notes), default=0)
        section = generator_options.generate_sections.add(
            start_time=final_note_end + seconds_per_step,
            end_time=total_seconds)
        if section.start_time >= section.end_time:
            tf.logging.fatal(
                'Priming sequence is longer than the total number of steps '
                'requested: Priming sequence length: %s, Generation length '
                'requested: %s', section.start_time, total_seconds)
            return
    else:
        input_sequence = music_pb2.NoteSequence()
        input_sequence.tempos.add().qpm = qpm
        generator_options.generate_sections.add(
            start_time=0, end_time=total_seconds)

    option_args = generator_options.args
    option_args['temperature'].float_value = FLAGS.temperature
    option_args['beam_size'].int_value = FLAGS.beam_size
    option_args['branch_factor'].int_value = FLAGS.branch_factor
    option_args['steps_per_iteration'].int_value = FLAGS.steps_per_iteration

    tf.logging.debug('input_sequence: %s', input_sequence)
    tf.logging.debug('generator_options: %s', generator_options)

    # Issue the generate request num_outputs times, saving each result as a
    # MIDI file whose name carries a timestamp and a zero-padded index.
    timestamp = time.strftime('%Y-%m-%d_%H%M%S')
    index_width = len(str(FLAGS.num_outputs))
    for output_number in range(1, FLAGS.num_outputs + 1):
        generated_sequence = generator.generate(
            input_sequence, generator_options)
        midi_filename = '%s_%s.mid' % (
            timestamp, str(output_number).zfill(index_width))
        midi_path = os.path.join(FLAGS.output_dir, midi_filename)
        note_seq.sequence_proto_to_midi_file(generated_sequence, midi_path)

    tf.logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs,
                    FLAGS.output_dir)