Example #1
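This method is the core generation step of a DrumsRnnSequenceGenerator: it validates the requested GeneratorOptions, quantizes the priming NoteSequence, extracts a drum track from it, and asks the underlying model to extend that track up to the requested end step.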
    def _generate(self, input_sequence, generator_options):
        if len(generator_options.input_sections) > 1:
            raise sequence_generator.SequenceGeneratorError(
                'This model supports at most one input_sections message, but got %s'
                % len(generator_options.input_sections))
        if len(generator_options.generate_sections) != 1:
            raise sequence_generator.SequenceGeneratorError(
                'This model supports only 1 generate_sections message, but got %s'
                % len(generator_options.generate_sections))

        if input_sequence and input_sequence.tempos:
            qpm = input_sequence.tempos[0].qpm
        else:
            qpm = note_seq.DEFAULT_QUARTERS_PER_MINUTE
        steps_per_second = note_seq.steps_per_quarter_to_steps_per_second(
            self.steps_per_quarter, qpm)

        generate_section = generator_options.generate_sections[0]
        if generator_options.input_sections:
            input_section = generator_options.input_sections[0]
            primer_sequence = note_seq.trim_note_sequence(
                input_sequence, input_section.start_time,
                input_section.end_time)
            input_start_step = note_seq.quantize_to_step(
                input_section.start_time,
                steps_per_second,
                quantize_cutoff=0.0)
        else:
            primer_sequence = input_sequence
            input_start_step = 0

        if primer_sequence.notes:
            last_end_time = max(n.end_time for n in primer_sequence.notes)
        else:
            last_end_time = 0
        if last_end_time > generate_section.start_time:
            raise sequence_generator.SequenceGeneratorError(
                'Got GenerateSection request for section that is before the end of '
                'the NoteSequence. This model can only extend sequences. Requested '
                'start time: %s, Final note end time: %s' %
                (generate_section.start_time, last_end_time))

        # Quantize the priming sequence.
        quantized_sequence = note_seq.quantize_note_sequence(
            primer_sequence, self.steps_per_quarter)
        # Setting gap_bars to infinite ensures that the entire input will be used.
        extracted_drum_tracks, _ = drum_pipelines.extract_drum_tracks(
            quantized_sequence,
            search_start_step=input_start_step,
            min_bars=0,
            gap_bars=float('inf'),
            ignore_is_drum=True)
        assert len(extracted_drum_tracks) <= 1

        start_step = note_seq.quantize_to_step(generate_section.start_time,
                                               steps_per_second,
                                               quantize_cutoff=0.0)
        # Note that when quantizing end_step, we set quantize_cutoff to 1.0 so it
        # always rounds down. This avoids generating a sequence that ends at 5.0
        # seconds when the requested end time is 4.99.
        end_step = note_seq.quantize_to_step(generate_section.end_time,
                                             steps_per_second,
                                             quantize_cutoff=1.0)

        if extracted_drum_tracks and extracted_drum_tracks[0]:
            drums = extracted_drum_tracks[0]
        else:
            # If no drum track could be extracted, create an empty drum track that
            # starts 1 step before the request start_step. This will result in 1 step
            # of silence when the drum track is extended below.
            steps_per_bar = int(
                note_seq.steps_per_bar_in_quantized_sequence(
                    quantized_sequence))
            drums = note_seq.DrumTrack(
                [],
                start_step=max(0, start_step - 1),
                steps_per_bar=steps_per_bar,
                steps_per_quarter=self.steps_per_quarter)

        # Ensure that the drum track extends up to the step we want to start
        # generating.
        drums.set_length(start_step - drums.start_step)

        # Extract generation arguments from generator options.
        arg_types = {
            'temperature': lambda arg: arg.float_value,
            'beam_size': lambda arg: arg.int_value,
            'branch_factor': lambda arg: arg.int_value,
            'steps_per_iteration': lambda arg: arg.int_value
        }
        args = dict((name, value_fn(generator_options.args[name]))
                    for name, value_fn in arg_types.items()
                    if name in generator_options.args)

        generated_drums = self._model.generate_drum_track(
            end_step - drums.start_step, drums, **args)
        generated_sequence = generated_drums.to_sequence(qpm=qpm)
        assert (generated_sequence.total_time -
                generate_section.end_time) <= 1e-5
        return generated_sequence
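The _generate method above is normally reached through the generator's public generate() call, as Example #2 shows. A minimal sketch of such a request, assuming a DrumsRnnSequenceGenerator instance named generator has already been constructed elsewhere (e.g. from a trained checkpoint or bundle, which is outside this excerpt):

import note_seq
from note_seq.protobuf import generator_pb2

# Prime with a single bass drum hit (MIDI pitch 36), the same fallback primer
# Example #2 uses.
primer_drums = note_seq.DrumTrack([frozenset([36])])
primer_sequence = primer_drums.to_sequence(qpm=120)

# Request a continuation starting one step after the primer ends; at 120 QPM
# and 4 steps per quarter, one step is 60 / 120 / 4 = 0.125 seconds.
options = generator_pb2.GeneratorOptions()
options.generate_sections.add(
    start_time=primer_sequence.total_time + 0.125, end_time=4.0)
options.args['temperature'].float_value = 1.0

# generate() validates the request and dispatches to _generate() above.
# generator is assumed to be a DrumsRnnSequenceGenerator built elsewhere.
generated_sequence = generator.generate(primer_sequence, options)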
Example #2
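This command-line driver builds a priming sequence from the primer_drums or primer_midi flags (falling back to a single bass drum hit), fills in a GeneratorOptions request that extends the primer to num_steps steps, and writes num_outputs generated drum tracks to output_dir as MIDI files.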
def run_with_flags(generator):
    """Generates drum tracks and saves them as MIDI files.

  Uses the options specified by the flags defined in this module.

  Args:
    generator: The DrumsRnnSequenceGenerator to use for generation.
  """
    if not FLAGS.output_dir:
        tf.logging.fatal('--output_dir required')
        return
    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)

    primer_midi = None
    if FLAGS.primer_midi:
        primer_midi = os.path.expanduser(FLAGS.primer_midi)

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    primer_sequence = None
    qpm = FLAGS.qpm if FLAGS.qpm else note_seq.DEFAULT_QUARTERS_PER_MINUTE
    if FLAGS.primer_drums:
        primer_drums = note_seq.DrumTrack([
            frozenset(pitches)
            for pitches in ast.literal_eval(FLAGS.primer_drums)
        ])
        primer_sequence = primer_drums.to_sequence(qpm=qpm)
    elif primer_midi:
        primer_sequence = note_seq.midi_file_to_sequence_proto(primer_midi)
        if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
            qpm = primer_sequence.tempos[0].qpm
    else:
        tf.logging.warning(
            'No priming sequence specified. Defaulting to a single bass drum hit.'
        )
        primer_drums = note_seq.DrumTrack([frozenset([36])])
        primer_sequence = primer_drums.to_sequence(qpm=qpm)

    # Derive the total number of seconds to generate based on the QPM of the
    # priming sequence and the num_steps flag.
    seconds_per_step = 60.0 / qpm / generator.steps_per_quarter
    total_seconds = FLAGS.num_steps * seconds_per_step

    # Specify start/stop time for generation based on starting generation at the
    # end of the priming sequence and continuing until the sequence is num_steps
    # long.
    generator_options = generator_pb2.GeneratorOptions()
    if primer_sequence:
        input_sequence = primer_sequence
        # Set the start time to begin on the next step after the last note ends.
        if primer_sequence.notes:
            last_end_time = max(n.end_time for n in primer_sequence.notes)
        else:
            last_end_time = 0
        generate_section = generator_options.generate_sections.add(
            start_time=last_end_time + seconds_per_step,
            end_time=total_seconds)

        if generate_section.start_time >= generate_section.end_time:
            tf.logging.fatal(
                'Priming sequence is longer than the total number of steps '
                'requested: Priming sequence length: %s, Generation length '
                'requested: %s', generate_section.start_time, total_seconds)
            return
    else:
        input_sequence = music_pb2.NoteSequence()
        input_sequence.tempos.add().qpm = qpm
        generate_section = generator_options.generate_sections.add(
            start_time=0, end_time=total_seconds)
    generator_options.args['temperature'].float_value = FLAGS.temperature
    generator_options.args['beam_size'].int_value = FLAGS.beam_size
    generator_options.args['branch_factor'].int_value = FLAGS.branch_factor
    generator_options.args[
        'steps_per_iteration'].int_value = FLAGS.steps_per_iteration
    tf.logging.debug('input_sequence: %s', input_sequence)
    tf.logging.debug('generator_options: %s', generator_options)

    # Make the generate request num_outputs times and save the output as midi
    # files.
    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
    digits = len(str(FLAGS.num_outputs))
    for i in range(FLAGS.num_outputs):
        generated_sequence = generator.generate(input_sequence,
                                                generator_options)

        midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
        midi_path = os.path.join(FLAGS.output_dir, midi_filename)
        note_seq.sequence_proto_to_midi_file(generated_sequence, midi_path)

    tf.logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs,
                    FLAGS.output_dir)
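run_with_flags() reads its options from module-level flags that are not shown in this excerpt. A minimal sketch of the flag definitions it expects, assuming the TF1 compatibility API (import tensorflow.compat.v1 as tf) that the function body already relies on; the flag names come from the code above, while the defaults and help strings are illustrative only:

import tensorflow.compat.v1 as tf

FLAGS = tf.app.flags.FLAGS
# Flag names match the FLAGS attributes used in run_with_flags(); defaults and
# help strings here are placeholders, not the library's own definitions.
tf.app.flags.DEFINE_string('output_dir', '/tmp/drums_rnn/generated',
                           'Directory where generated MIDI files are written.')
tf.app.flags.DEFINE_string('primer_midi', '',
                           'Path to a MIDI file to use as a primer.')
tf.app.flags.DEFINE_string('primer_drums', '',
                           'String representation of a Python list of tuples of '
                           'drum pitches, e.g. "[(36,), (), (42,)]".')
tf.app.flags.DEFINE_float('qpm', None,
                          'Tempo in quarter notes per minute; if unset, the primer '
                          'tempo (or the library default) is used.')
tf.app.flags.DEFINE_integer('num_steps', 128,
                            'Total length of the generated track, in steps.')
tf.app.flags.DEFINE_integer('num_outputs', 10,
                            'Number of drum tracks to generate.')
tf.app.flags.DEFINE_float('temperature', 1.0,
                          'Softmax sampling temperature; higher values are more random.')
tf.app.flags.DEFINE_integer('beam_size', 1, 'Beam width used during generation.')
tf.app.flags.DEFINE_integer('branch_factor', 1, 'Beam search branch factor.')
tf.app.flags.DEFINE_integer('steps_per_iteration', 1,
                            'Steps generated per beam search iteration.')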