Esempio n. 1
0
  def CheckReadWriteMidi(self, filename):
    """Round-trip a MIDI file through a NoteSequence and verify the result.

    The file is parsed with pretty_midi and converted to a NoteSequence proto.
    That proto is then written back out as MIDI to a temporary file, re-read,
    and compared against the proto with CheckPrettyMidiAndSequence.

    Args:
      filename: Path of the MIDI file to round-trip.
    """
    original = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(filename))

    # Serialize the proto to a scratch MIDI file and parse it again while the
    # temporary file still exists (it is removed when the `with` exits).
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as scratch:
      midi_io.sequence_proto_to_midi_file(original, scratch.name)
      round_tripped = pretty_midi.PrettyMIDI(scratch.name)

    self.CheckPrettyMidiAndSequence(round_tripped, original)
Esempio n. 2
0
    def CheckReadWriteMidi(self, filename):
        """Round-trip a MIDI file through a NoteSequence and verify the result.

        The file is parsed with pretty_midi and converted to a NoteSequence
        proto. That proto is written back out as MIDI to a temporary file,
        re-read, and compared against the proto.

        Args:
            filename: Path of the MIDI file to round-trip.
        """
        original = midi_io.midi_to_sequence_proto(
            pretty_midi.PrettyMIDI(filename))

        # Serialize the proto to a scratch MIDI file and parse it again while
        # the temporary file still exists (it is removed when the `with` exits).
        with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as scratch:
            midi_io.sequence_proto_to_midi_file(original, scratch.name)
            round_tripped = pretty_midi.PrettyMIDI(scratch.name)

        self.CheckPrettyMidiAndSequence(round_tripped, original)
Esempio n. 3
0
def main(_):
    """Sample melodies from a trained basic RNN and write them as MIDI files.

    The first melody extracted from --primer_midi primes the sampler. Samples
    are collected until --num_outputs are available, or until the sampler
    stops producing any. Each is then written to --output_dir as
    basic_rnn_sample_<i>.mid at the primer's tempo.

    Args:
        _: Unused; presumably the argv list supplied by the app launcher.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    root.addHandler(ch)

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
    # Use the primer's first tempo. Fall back to 120 BPM when it has none or
    # the tempo is 0, which would presumably produce degenerate timings.
    if primer_sequence.tempos and primer_sequence.tempos[0].bpm:
        bpm = primer_sequence.tempos[0].bpm
    else:
        bpm = 120.0

    extracted_melodies = melodies_lib.extract_melodies(primer_sequence,
                                                       min_bars=1,
                                                       min_unique_pitches=1)

    if not extracted_melodies:
        logging.info('No melodies were extracted from MIDI file %s',
                     FLAGS.primer_midi)
        return

    graph = make_graph(hparams_string=FLAGS.hparams)

    checkpoint_dir = os.path.join(FLAGS.experiment_run_dir, 'train')

    generated = []
    while len(generated) < FLAGS.num_outputs:
        # Materialize the batch so an empty result is detectable even when
        # the sampler returns a lazy iterable.
        batch = list(sampler_loop(graph, checkpoint_dir,
                                  extracted_melodies[0], FLAGS.num_steps))
        if not batch:
            # Without this check an empty batch would make the loop spin
            # forever.
            logging.warning('Sampler produced no melodies; stopping with %d.',
                            len(generated))
            break
        generated.extend(batch)

    # Only write as many files as were actually generated.
    num_written = min(FLAGS.num_outputs, len(generated))
    for i in range(num_written):
        sequence = generated[i].to_sequence(bpm=bpm)
        midi_io.sequence_proto_to_midi_file(
            sequence,
            os.path.join(FLAGS.output_dir, 'basic_rnn_sample_%d.mid' % i))

    logging.info('Wrote %d MIDI files to %s', num_written, FLAGS.output_dir)
Esempio n. 4
0
def main(_):
  """Sample melodies from a trained basic RNN and write them as MIDI files.

  The first melody extracted from --primer_midi primes the sampler. Samples
  are collected until --num_outputs are available, or until the sampler stops
  producing any. Each is then written to --output_dir as
  basic_rnn_sample_<i>.mid at the primer's tempo.

  Args:
    _: Unused; presumably the argv list supplied by the app launcher.
  """
  root = logging.getLogger()
  root.setLevel(logging.INFO)
  ch = logging.StreamHandler(sys.stdout)
  ch.setLevel(logging.INFO)
  root.addHandler(ch)

  if not os.path.isdir(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
  # Use the primer's first tempo. Fall back to 120 BPM when it has none or
  # the tempo is 0, which would presumably produce degenerate timings.
  if primer_sequence.tempos and primer_sequence.tempos[0].bpm:
    bpm = primer_sequence.tempos[0].bpm
  else:
    bpm = 120.0

  extracted_melodies = melodies_lib.extract_melodies(primer_sequence,
                                                     min_bars=1,
                                                     min_unique_pitches=1)

  if not extracted_melodies:
    logging.info('No melodies were extracted from MIDI file %s',
                 FLAGS.primer_midi)
    return

  graph = make_graph(hparams_string=FLAGS.hparams)

  checkpoint_dir = os.path.join(FLAGS.experiment_run_dir, 'train')

  generated = []
  while len(generated) < FLAGS.num_outputs:
    # Materialize the batch so an empty result is detectable even when the
    # sampler returns a lazy iterable.
    batch = list(sampler_loop(graph, classes_to_melody,
                              checkpoint_dir,
                              extracted_melodies[0],
                              FLAGS.num_steps))
    if not batch:
      # Without this check an empty batch would make the loop spin forever.
      logging.warning('Sampler produced no melodies; stopping with %d.',
                      len(generated))
      break
    generated.extend(batch)

  # Only write as many files as were actually generated. `range` (rather
  # than the Python-2-only `xrange`) keeps this runnable on Python 3.
  num_written = min(FLAGS.num_outputs, len(generated))
  for i in range(num_written):
    sequence = generated[i].to_sequence(bpm=bpm)
    midi_io.sequence_proto_to_midi_file(
        sequence,
        os.path.join(FLAGS.output_dir, 'basic_rnn_sample_%d.mid' % i))

  logging.info('Wrote %d MIDI files to %s', num_written, FLAGS.output_dir)
Esempio n. 5
0
    def testIsDrumDetection(self):
        """Check that drum (channel 9) tracks keep is_drum after a round trip.

        self.midi_is_drum_filename is a MIDI file with two tracks set to
        channel 9 (is_drum == True), each containing one Note On. The file is
        converted to a NoteSequence and written back out as MIDI. Both the
        source and the rewritten file must then contain two channel-9 Note
        Ons. This guards against a bug where the second track lost is_drum
        and the drum track was remapped to an instrument track.
        """
        def count_drum_note_ons(pattern):
            # In MIDI, a Note On with velocity 0 acts as a Note Off, so only
            # positive velocities are counted.
            return sum(
                1 for track in pattern for event in track
                if event.name == 'Note On' and event.velocity > 0 and
                event.channel == 9)

        sequence = midi_io.midi_file_to_sequence_proto(
            self.midi_is_drum_filename)
        with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
            midi_io.sequence_proto_to_midi_file(sequence, temp_file.name)
            source_data = py_midi.read_midifile(self.midi_is_drum_filename)
            rewritten_data = py_midi.read_midifile(temp_file.name)

        self.assertEqual(
            [count_drum_note_ons(source_data),
             count_drum_note_ons(rewritten_data)],
            [2, 2])
Esempio n. 6
0
    def CheckReadWriteMidi(self, filename):
        """Round-trip a MIDI file through a NoteSequence and verify the result.

        The input is first normalized by loading it with pretty_midi and
        writing it back out to a temporary file. The normalized copy is
        converted to a NoteSequence proto. That proto is then written to a
        fresh temporary MIDI file, re-read, and compared against the proto.

        Args:
            filename: Path of the MIDI file to round-trip.
        """
        # TODO(deck): The input MIDI file is opened in pretty-midi and
        # re-written to a temp file, sanitizing the MIDI data (reordering
        # note ons, etc). Issue 85 in the pretty-midi GitHub
        # (http://github.com/craffel/pretty-midi/issues/85) requests that
        # this sanitization be available outside of the context of a file
        # write. If that is implemented, this rewrite code should be
        # modified or deleted.
        with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as sanitized:
            pretty_midi.PrettyMIDI(filename).write(sanitized.name)
            sequence = midi_io.midi_to_sequence_proto(
                pretty_midi.PrettyMIDI(sanitized.name))

        # Serialize the NoteSequence and load it back for comparison.
        with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as roundtrip:
            midi_io.sequence_proto_to_midi_file(sequence, roundtrip.name)
            reloaded = pretty_midi.PrettyMIDI(roundtrip.name)

        self.CheckPrettyMidiAndSequence(reloaded, sequence)
Esempio n. 7
0
def run_with_flags(melody_rnn_sequence_generator):
    """Generates melodies and saves them as MIDI files.

    Uses the options specified by the flags defined in this module. Intended
    to be called from the main function of one of the melody generator
    modules.

    Tempo: --bpm, or melodies_lib.DEFAULT_BEATS_PER_MINUTE when unset. A
    --primer_midi file whose first tempo is non-zero overrides it.

    Generation span: with a primer, generation starts one step after the
    last primer note ends. Without a primer, it starts at 0. In both cases it
    ends once the sequence is --num_steps steps long.

    Output: --num_outputs files named '<timestamp>_<index>.mid' are written
    to --output_dir.

    Args:
      melody_rnn_sequence_generator: A MelodyRnnSequenceGenerator object
          specific to your model.
    """
    tf.logging.set_verbosity(FLAGS.log)

    if not FLAGS.output_dir:
        tf.logging.fatal('--output_dir required')
        return

    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
    if FLAGS.primer_midi:
        FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    primer_sequence = None
    bpm = FLAGS.bpm if FLAGS.bpm else melodies_lib.DEFAULT_BEATS_PER_MINUTE
    if FLAGS.primer_melody:
        primer_melody = melodies_lib.MonophonicMelody()
        primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
        primer_sequence = primer_melody.to_sequence(bpm=bpm)
    elif FLAGS.primer_midi:
        primer_sequence = midi_io.midi_file_to_sequence_proto(
            FLAGS.primer_midi)
        # Only adopt the file's tempo when it is present and non-zero.
        if primer_sequence.tempos and primer_sequence.tempos[0].bpm:
            bpm = primer_sequence.tempos[0].bpm

    # Derive the total number of seconds to generate based on the BPM and
    # the num_steps flag.
    total_seconds = _steps_to_seconds(FLAGS.num_steps, bpm)

    # A single generate section always ends at total_seconds; only its start
    # time and the input sequence depend on whether there is a primer.
    generate_request = generator_pb2.GenerateSequenceRequest()
    generate_section = (
        generate_request.generator_options.generate_sections.add())
    generate_section.end_time_seconds = total_seconds

    if primer_sequence is not None:
        generate_request.input_sequence.CopyFrom(primer_sequence)
        # Begin on the step after the last primer note ends (0 if no notes).
        # A linear max replaces sorting every note just to read the last one.
        last_end_time = (
            max(note.end_time for note in primer_sequence.notes)
            if primer_sequence.notes else 0)
        generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
            1, bpm)

        if generate_section.start_time_seconds >= generate_section.end_time_seconds:
            tf.logging.fatal(
                'Priming sequence is longer than the total number of steps '
                'requested: Priming sequence length: %s, Generation length '
                'requested: %s', generate_section.start_time_seconds,
                total_seconds)
            return
    else:
        # No primer: generate from scratch and record the tempo to use.
        generate_section.start_time_seconds = 0
        generate_request.input_sequence.tempos.add().bpm = bpm
    tf.logging.info('generate_request: %s', generate_request)

    # Make the generate request num_outputs times and save the output as midi
    # files.
    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
    digits = len(str(FLAGS.num_outputs))
    for i in range(FLAGS.num_outputs):
        generate_response = melody_rnn_sequence_generator.generate(
            generate_request)

        midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
        midi_path = os.path.join(FLAGS.output_dir, midi_filename)
        midi_io.sequence_proto_to_midi_file(
            generate_response.generated_sequence, midi_path)

    tf.logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs,
                    FLAGS.output_dir)
Esempio n. 8
0
def run_generate(graph, train_dir, output_dir, melody_encoder_decoder,
                 primer_melody, num_steps, bpm):
  """Generates melodies and saves them as MIDI files.

  One melody is generated per entry in the model's batch (the first dimension
  of the 'softmax' tensor). The number of files written therefore equals that
  batch size. Files are named '<YYYY-mm-dd_HHMMSS>_<index>.mid', where the
  1-based index is zero-padded to a common width.

  `primer_melody.squash(...)` is called on the caller's object. The generated
  melodies are transposed by the negated return value before being written.
  NOTE(review): this suggests `squash` modifies `primer_melody` in place,
  which would leave the caller's melody transposed after this returns. Confirm
  this is acceptable to callers.

  Args:
    graph: A tf.Graph object containing the model.
    train_dir: The path to the directory where the latest checkpoint will be
        loaded from.
    output_dir: The path to the directory where MIDI files will be saved to.
    melody_encoder_decoder: A melodies_lib.MelodyEncoderDecoder object.
    primer_melody: A melodies_lib.MonophonicMelody object that will be used as
        the priming melody. If the priming melody is empty, melodies will be
        generated from scratch.
    num_steps: The total number of steps the final melodies should be,
        priming melody + generated steps. Each step is a 16th of a bar.
    bpm: The tempo in beats per minute that the generated MIDI files will have.
  """
  # Fetch the model's tensors by collection name; only the first element of
  # each collection is used.
  inputs = graph.get_collection('inputs')[0]
  initial_state = graph.get_collection('initial_state')[0]
  final_state = graph.get_collection('final_state')[0]
  softmax = graph.get_collection('softmax')[0]
  # The first dimension of softmax sets how many melodies are generated.
  batch_size = softmax.get_shape()[0].value

  # Presumably fits the primer into [min_note, max_note] and shifts it toward
  # transpose_to_key. The returned amount is undone on each output melody
  # before writing.
  transpose_amount = primer_melody.squash(
      melody_encoder_decoder.min_note, melody_encoder_decoder.max_note,
      melody_encoder_decoder.transpose_to_key)

  # Seed one melody per batch entry. Use a copy of the primer's events when
  # the primer is non-empty; otherwise use a single random note in
  # [min_note, max_note] (randint is inclusive on both ends).
  melodies = []
  for _ in xrange(batch_size):
    melody = melodies_lib.MonophonicMelody()
    if primer_melody.events:
      melody.from_event_list(primer_melody.events)
    else:
      melody.events = [random.randint(melody_encoder_decoder.min_note,
                                      melody_encoder_decoder.max_note)]
    melodies.append(melody)

  with graph.as_default():
    saver = tf.train.Saver()
    with tf.Session() as sess:
      # NOTE(review): tf.train.latest_checkpoint returns None when train_dir
      # has no checkpoint, and restore then fails. Consider an explicit check.
      checkpoint_file = tf.train.latest_checkpoint(train_dir)
      tf.logging.info('Checkpoint used: %s', checkpoint_file)
      tf.logging.info('Generating melodies...')
      saver.restore(sess, checkpoint_file)

      # Step the RNN once per step not covered by the primer.
      # - First step: feeds the full melodies (full_length=True) together with
      #   the model's initial state.
      # - Later steps: feed get_inputs_batch's default input (presumably only
      #   the latest event) and carry over the previous step's final state.
      # NOTE(review): with an empty primer each melody already holds one
      # random seed note, but the loop still runs num_steps times. If
      # extend_melodies adds one event per call, outputs are num_steps + 1
      # long. Confirm this is intended.
      final_state_ = None
      for i in xrange(num_steps - len(primer_melody)):
        if i == 0:
          inputs_ = melody_encoder_decoder.get_inputs_batch(melodies,
                                                            full_length=True)
          initial_state_ = sess.run(initial_state)
        else:
          inputs_ = melody_encoder_decoder.get_inputs_batch(melodies)
          initial_state_ = final_state_

        feed_dict = {inputs: inputs_, initial_state: initial_state_}
        final_state_, softmax_ = sess.run([final_state, softmax], feed_dict)
        # Extend every melody from this step's softmax output. The choice of
        # each new event is made inside extend_melodies.
        melody_encoder_decoder.extend_melodies(melodies, softmax_)

  # All files from this run share one timestamp prefix. Each melody is
  # transposed back by the amount applied by squash before it is written.
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
  digits = len(str(len(melodies)))
  for i, melody in enumerate(melodies):
    melody.transpose(-transpose_amount)
    sequence = melody.to_sequence(bpm=bpm)
    midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
    midi_path = os.path.join(output_dir, midi_filename)
    midi_io.sequence_proto_to_midi_file(sequence, midi_path)

  tf.logging.info('Wrote %d MIDI files to %s', len(melodies), output_dir)
def run_with_flags(melody_rnn_sequence_generator):
  """Generates melodies and saves them as MIDI files.

  Uses the options specified by the flags defined in this module. Intended to
  be called from the main function of one of the melody generator modules.

  Primer: taken from --primer_melody if given, else from --primer_midi. If
  neither is given, generation starts from scratch at time 0.

  Tempo: the primer's first tempo when present and non-zero, otherwise
  melodies_lib.DEFAULT_BEATS_PER_MINUTE.

  Generation span: with a primer, generation starts one step after the last
  primer note ends. It ends once the sequence is --num_steps steps long.

  Args:
    melody_rnn_sequence_generator: A MelodyRnnSequenceGenerator object specific
        to your model.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.output_dir:
    tf.logging.fatal('--output_dir required')
    return

  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
  if FLAGS.primer_midi:
    FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  primer_sequence = None
  if FLAGS.primer_melody:
    primer_melody = melodies_lib.MonophonicMelody()
    primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
    primer_sequence = primer_melody.to_sequence()
  elif FLAGS.primer_midi:
    primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)

  # Derive the total number of seconds to generate based on the BPM of the
  # priming sequence and the num_steps flag. Previously a missing primer
  # crashed here on `None.tempos`, and a tempo of 0 was accepted.
  bpm = melodies_lib.DEFAULT_BEATS_PER_MINUTE
  if (primer_sequence is not None and primer_sequence.tempos and
      primer_sequence.tempos[0].bpm):
    bpm = primer_sequence.tempos[0].bpm
  total_seconds = _steps_to_seconds(FLAGS.num_steps, bpm)

  # Specify start/stop time for generation. Generation starts at the end of
  # the priming sequence (or at 0 without one) and continues until the
  # sequence is num_steps long.
  generate_request = generator_pb2.GenerateSequenceRequest()
  generate_section = generate_request.generator_options.generate_sections.add()
  generate_section.end_time_seconds = total_seconds

  if primer_sequence is None:
    # No primer: generate from scratch and record the tempo to use.
    generate_section.start_time_seconds = 0
    generate_request.input_sequence.tempos.add().bpm = bpm
  else:
    generate_request.input_sequence.CopyFrom(primer_sequence)
    # Set the start time to begin on the next step after the last note ends.
    # A linear max replaces sorting every note just to read the last one.
    last_end_time = (max(note.end_time for note in primer_sequence.notes)
                     if primer_sequence.notes else 0)
    generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
        1, bpm)

    if (generate_section.start_time_seconds >=
        generate_section.end_time_seconds):
      tf.logging.fatal(
          'Priming sequence is longer than the total number of steps requested: '
          'Priming sequence length: %s, Generation length requested: %s',
          generate_section.start_time_seconds, total_seconds)
      return

  # Make the generate request num_outputs times and save the output as midi
  # files.
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
  digits = len(str(FLAGS.num_outputs))
  for i in range(FLAGS.num_outputs):
    generate_response = melody_rnn_sequence_generator.generate(
        generate_request)

    midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
    midi_path = os.path.join(FLAGS.output_dir, midi_filename)
    midi_io.sequence_proto_to_midi_file(
        generate_response.generated_sequence, midi_path)

  tf.logging.info('Wrote %d MIDI files to %s',
                  FLAGS.num_outputs, FLAGS.output_dir)