def run(melody_encoder_decoder, build_graph):
    """Generates melodies and saves them as MIDI files.

    Args:
      melody_encoder_decoder: A melodies_lib.MelodyEncoderDecoder object
          specific to your model.
      build_graph: A function that, when called, returns the tf.Graph object
          for your model. It is invoked here with the positional arguments
          (mode, hparams_string, input_size, num_classes).
          For an example usage, see models/basic_rnn/basic_rnn_graph.py.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Both --run_dir and --output_dir are mandatory; bail out early if missing.
    if not FLAGS.run_dir:
        tf.logging.fatal("--run_dir required")
        return
    if not FLAGS.output_dir:
        tf.logging.fatal("--output_dir required")
        return

    # Expand '~' so the flags work with home-relative paths.
    FLAGS.run_dir = os.path.expanduser(FLAGS.run_dir)
    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
    if FLAGS.primer_midi:
        FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

    # Parse user hparams (a Python-dict literal), then force generation-time
    # settings: one batch row per requested output, and no dropout.
    hparams = ast.literal_eval(FLAGS.hparams if FLAGS.hparams else "{}")
    hparams["batch_size"] = FLAGS.num_outputs
    hparams["dropout_keep_prob"] = 1.0
    hparams_string = repr(hparams)

    graph = build_graph(
        "generate", hparams_string, melody_encoder_decoder.input_size, melody_encoder_decoder.num_classes
    )

    # Checkpoints are expected under <run_dir>/train.
    train_dir = os.path.join(FLAGS.run_dir, "train")

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    # Priming: --primer_melody (an event-list literal) wins over --primer_midi;
    # with neither, generation starts from an empty melody at the default BPM.
    primer_melody = melodies_lib.Melody()
    bpm = melodies_lib.DEFAULT_BEATS_PER_MINUTE
    if FLAGS.primer_melody:
        primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
    elif FLAGS.primer_midi:
        primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
        if primer_sequence.tempos:
            bpm = primer_sequence.tempos[0].bpm
        extracted_melodies = melodies_lib.extract_melodies(primer_sequence, min_bars=0, min_unique_pitches=1)
        if extracted_melodies:
            primer_melody = extracted_melodies[0]
        else:
            tf.logging.info(
                "No melodies were extracted from the MIDI file %s. " "Melodies will be generated from scratch.",
                FLAGS.primer_midi,
            )

    run_generate(graph, train_dir, FLAGS.output_dir, melody_encoder_decoder, primer_melody, FLAGS.num_steps, bpm)
def main(_):
    """Script entry point: samples melodies from a primer MIDI file.

    Reads the primer from FLAGS.primer_midi, extracts a melody, samples
    FLAGS.num_outputs continuations of FLAGS.num_steps steps each from the
    model checkpointed under FLAGS.experiment_run_dir/train, and writes them
    as MIDI files into FLAGS.output_dir.

    Args:
      _: Unused argv remainder passed by the app runner.
    """
    # Mirror all INFO logging to stdout.
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    root.addHandler(ch)

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
    # Fall back to 120 BPM when the primer MIDI carries no tempo events.
    bpm = primer_sequence.tempos[0].bpm if primer_sequence.tempos else 120.0

    extracted_melodies = melodies_lib.extract_melodies(primer_sequence,
                                                       min_bars=1,
                                                       min_unique_pitches=1)

    if not extracted_melodies:
        # Lazy %-args: formatting is deferred until the record is emitted.
        logging.info('No melodies were extracted from MIDI file %s',
                     FLAGS.primer_midi)
        return

    graph = make_graph(hparams_string=FLAGS.hparams)

    checkpoint_dir = os.path.join(FLAGS.experiment_run_dir, 'train')

    # sampler_loop may return fewer melodies than requested per call, so keep
    # sampling until we have at least num_outputs of them.
    generated = []
    while len(generated) < FLAGS.num_outputs:
        generated.extend(
            sampler_loop(graph, checkpoint_dir, extracted_melodies[0],
                         FLAGS.num_steps))

    for i in range(FLAGS.num_outputs):
        sequence = generated[i].to_sequence(bpm=bpm)
        midi_io.sequence_proto_to_midi_file(
            sequence,
            os.path.join(FLAGS.output_dir, 'basic_rnn_sample_%d.mid' % i))

    logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs,
                 FLAGS.output_dir)
# --- Example 3 ---
def main(_):
  """Script entry point: samples melodies from a primer MIDI file.

  Reads the primer from FLAGS.primer_midi, extracts a melody, samples
  FLAGS.num_outputs continuations of FLAGS.num_steps steps each from the
  model checkpointed under FLAGS.experiment_run_dir/train, and writes them
  as MIDI files into FLAGS.output_dir.

  Args:
    _: Unused argv remainder passed by the app runner.
  """
  # Mirror all INFO logging to stdout.
  root = logging.getLogger()
  root.setLevel(logging.INFO)
  ch = logging.StreamHandler(sys.stdout)
  ch.setLevel(logging.INFO)
  root.addHandler(ch)

  if not os.path.isdir(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
  # Fall back to 120 BPM when the primer MIDI carries no tempo events.
  bpm = primer_sequence.tempos[0].bpm if primer_sequence.tempos else 120.0

  extracted_melodies = melodies_lib.extract_melodies(primer_sequence,
                                                     min_bars=1,
                                                     min_unique_pitches=1)

  if not extracted_melodies:
    # Lazy %-args: formatting is deferred until the record is emitted.
    logging.info('No melodies were extracted from MIDI file %s',
                 FLAGS.primer_midi)
    return

  graph = make_graph(hparams_string=FLAGS.hparams)

  checkpoint_dir = os.path.join(FLAGS.experiment_run_dir, 'train')

  # sampler_loop may return fewer melodies than requested per call, so keep
  # sampling until we have at least num_outputs of them.
  generated = []
  while len(generated) < FLAGS.num_outputs:
    generated.extend(sampler_loop(graph, classes_to_melody,
                                  checkpoint_dir,
                                  extracted_melodies[0],
                                  FLAGS.num_steps))

  # `xrange` is Python-2-only (NameError on Python 3); use `range`, matching
  # the sibling variant of this function in this file.
  for i in range(FLAGS.num_outputs):
    sequence = generated[i].to_sequence(bpm=bpm)
    midi_io.sequence_proto_to_midi_file(
        sequence,
        os.path.join(FLAGS.output_dir, 'basic_rnn_sample_%d.mid' % i))

  logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs, FLAGS.output_dir)
# --- Example 4 ---
    def testIsDrumDetection(self):
        """Verify that is_drum instruments are properly tracked.

        self.midi_is_drum_filename is a MIDI file containing two tracks
        set to channel 9 (is_drum == True). Each contains one NoteOn. This
        test is designed to catch a bug where the second track would lose
        is_drum, remapping the drum track to an instrument track.
        """
        # Round-trip the file: MIDI -> NoteSequence proto -> MIDI.
        sequence_proto = midi_io.midi_file_to_sequence_proto(
            self.midi_is_drum_filename)
        with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
            midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
            midi_data1 = py_midi.read_midifile(self.midi_is_drum_filename)
            midi_data2 = py_midi.read_midifile(temp_file.name)

        def count_drum_note_ons(midi_data):
            # Tally audible (velocity > 0) Note On events on channel 9.
            total = 0
            for track in midi_data:
                for event in track:
                    is_note_on = event.name == 'Note On' and event.velocity > 0
                    if is_note_on and event.channel == 9:
                        total += 1
            return total

        # Both the original file and its round-tripped copy must keep the
        # two drum-channel Note Ons.
        channel_counts = [count_drum_note_ons(data)
                          for data in (midi_data1, midi_data2)]
        self.assertEqual(channel_counts, [2, 2])
# --- Example 5 ---
def run_with_flags(melody_rnn_sequence_generator):
    """Generates melodies and saves them as MIDI files.

    Uses the options specified by the flags defined in this module. Intended
    to be called from the main function of one of the melody generator
    modules.

    Args:
      melody_rnn_sequence_generator: A MelodyRnnSequenceGenerator object
          specific to your model.
    """
    tf.logging.set_verbosity(FLAGS.log)

    if not FLAGS.output_dir:
        tf.logging.fatal('--output_dir required')
        return

    # Expand '~' so the flags work with home-relative paths.
    FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
    if FLAGS.primer_midi:
        FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    # Priming: --primer_melody (an event-list literal) wins over
    # --primer_midi. --bpm overrides the default, and a primer MIDI's first
    # tempo event overrides both.
    primer_sequence = None
    bpm = FLAGS.bpm if FLAGS.bpm else melodies_lib.DEFAULT_BEATS_PER_MINUTE
    if FLAGS.primer_melody:
        primer_melody = melodies_lib.MonophonicMelody()
        primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
        primer_sequence = primer_melody.to_sequence(bpm=bpm)
    elif FLAGS.primer_midi:
        primer_sequence = midi_io.midi_file_to_sequence_proto(
            FLAGS.primer_midi)
        if primer_sequence.tempos and primer_sequence.tempos[0].bpm:
            bpm = primer_sequence.tempos[0].bpm

    # Derive the total number of seconds to generate based on the BPM of the
    # priming sequence and the num_steps flag.
    total_seconds = _steps_to_seconds(FLAGS.num_steps, bpm)

    # Specify start/stop time for generation based on starting generation at the
    # end of the priming sequence and continuing until the sequence is num_steps
    # long.
    generate_request = generator_pb2.GenerateSequenceRequest()
    if primer_sequence:
        generate_request.input_sequence.CopyFrom(primer_sequence)
        generate_section = (
            generate_request.generator_options.generate_sections.add())
        # Set the start time to begin on the next step after the last note ends.
        notes_by_end_time = sorted(primer_sequence.notes,
                                   key=lambda n: n.end_time)
        last_end_time = notes_by_end_time[
            -1].end_time if notes_by_end_time else 0
        generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
            1, bpm)
        generate_section.end_time_seconds = total_seconds

        # If the primer already fills (or exceeds) the requested length there
        # is nothing left to generate.
        if generate_section.start_time_seconds >= generate_section.end_time_seconds:
            tf.logging.fatal(
                'Priming sequence is longer than the total number of steps '
                'requested: Priming sequence length: %s, Generation length '
                'requested: %s', generate_section.start_time_seconds,
                total_seconds)
            return
    else:
        # No primer: generate the whole sequence from scratch at `bpm`.
        generate_section = (
            generate_request.generator_options.generate_sections.add())
        generate_section.start_time_seconds = 0
        generate_section.end_time_seconds = total_seconds
        generate_request.input_sequence.tempos.add().bpm = bpm
    tf.logging.info('generate_request: %s', generate_request)

    # Make the generate request num_outputs times and save the output as midi
    # files.
    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
    # Zero-pad output indices so filenames sort naturally.
    digits = len(str(FLAGS.num_outputs))
    for i in range(FLAGS.num_outputs):
        generate_response = melody_rnn_sequence_generator.generate(
            generate_request)

        midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
        midi_path = os.path.join(FLAGS.output_dir, midi_filename)
        midi_io.sequence_proto_to_midi_file(
            generate_response.generated_sequence, midi_path)

    tf.logging.info('Wrote %d MIDI files to %s', FLAGS.num_outputs,
                    FLAGS.output_dir)
def run(melody_encoder_decoder, build_graph):
  """Generates melodies and saves them as MIDI files.

  Args:
    melody_encoder_decoder: A melodies_lib.MelodyEncoderDecoder object specific
        to your model.
    build_graph: A function that, when called, returns the tf.Graph object for
        your model. It is invoked here with the positional arguments
        (mode, hparams_string, input_size, num_classes).
        For an example usage, see models/basic_rnn/basic_rnn_graph.py.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Both --run_dir and --output_dir are mandatory; bail out early if missing.
  if not FLAGS.run_dir:
    tf.logging.fatal('--run_dir required')
    return
  if not FLAGS.output_dir:
    tf.logging.fatal('--output_dir required')
    return

  # Expand '~' so the flags work with home-relative paths.
  FLAGS.run_dir = os.path.expanduser(FLAGS.run_dir)
  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
  if FLAGS.primer_midi:
    FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

  # Parse user hparams (a Python-dict literal), then force generation-time
  # settings: one batch row per output, no dropout, and sampling temperature.
  hparams = ast.literal_eval(FLAGS.hparams if FLAGS.hparams else '{}')
  hparams['batch_size'] = FLAGS.num_outputs
  hparams['dropout_keep_prob'] = 1.0
  hparams['temperature'] = FLAGS.temperature
  hparams_string = repr(hparams)

  graph = build_graph('generate',
                      hparams_string,
                      melody_encoder_decoder.input_size,
                      melody_encoder_decoder.num_classes)

  # Checkpoints are expected under <run_dir>/train.
  train_dir = os.path.join(FLAGS.run_dir, 'train')

  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  # Priming: --primer_melody (an event-list literal) wins over --primer_midi;
  # with neither, generation starts from an empty melody at the default BPM.
  primer_melody = melodies_lib.MonophonicMelody()
  bpm = melodies_lib.DEFAULT_BEATS_PER_MINUTE
  if FLAGS.primer_melody:
    primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
  elif FLAGS.primer_midi:
    primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
    # Quantize the primer before melody extraction; the quantized sequence
    # also supplies the BPM.
    quantized_sequence = sequences_lib.QuantizedSequence()
    quantized_sequence.from_note_sequence(primer_sequence,
                                          DEFAULT_STEPS_PER_BEAT)
    bpm = quantized_sequence.bpm
    extracted_melodies = melodies_lib.extract_melodies(
        quantized_sequence, min_bars=0, min_unique_pitches=1,
        gap_bars=float('inf'), ignore_polyphonic_notes=True)
    if extracted_melodies:
      primer_melody = extracted_melodies[0]
    else:
      tf.logging.info('No melodies were extracted from the MIDI file %s. '
                      'Melodies will be generated from scratch.',
                      FLAGS.primer_midi)

  run_generate(graph, train_dir, FLAGS.output_dir, melody_encoder_decoder,
               primer_melody, FLAGS.num_steps, bpm)
def run_with_flags(melody_rnn_sequence_generator):
  """Generates melodies and saves them as MIDI files.

  Uses the options specified by the flags defined in this module. Intended to be
  called from the main function of one of the melody generator modules.

  Args:
    melody_rnn_sequence_generator: A MelodyRnnSequenceGenerator object specific
        to your model.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.output_dir:
    tf.logging.fatal('--output_dir required')
    return

  # Expand '~' so the flags work with home-relative paths.
  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
  if FLAGS.primer_midi:
    FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)

  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  # Priming: --primer_melody (an event-list literal) wins over --primer_midi.
  primer_sequence = None
  if FLAGS.primer_melody:
    primer_melody = melodies_lib.MonophonicMelody()
    primer_melody.from_event_list(ast.literal_eval(FLAGS.primer_melody))
    primer_sequence = primer_melody.to_sequence()
  elif FLAGS.primer_midi:
    primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)

  generate_request = generator_pb2.GenerateSequenceRequest()
  if primer_sequence is None:
    # Bug fix: the previous version dereferenced primer_sequence
    # unconditionally, raising AttributeError when neither --primer_melody nor
    # --primer_midi was supplied. Generate from scratch in that case instead,
    # matching the sibling run_with_flags variant in this file.
    bpm = melodies_lib.DEFAULT_BEATS_PER_MINUTE
    total_seconds = _steps_to_seconds(FLAGS.num_steps, bpm)
    generate_section = (
        generate_request.generator_options.generate_sections.add())
    generate_section.start_time_seconds = 0
    generate_section.end_time_seconds = total_seconds
    generate_request.input_sequence.tempos.add().bpm = bpm
  else:
    # Derive the total number of seconds to generate based on the BPM of the
    # priming sequence and the num_steps flag.
    bpm = (primer_sequence.tempos[0].bpm if primer_sequence.tempos
           else melodies_lib.DEFAULT_BEATS_PER_MINUTE)
    total_seconds = _steps_to_seconds(FLAGS.num_steps, bpm)

    # Specify start/stop time for generation based on starting generation at
    # the end of the priming sequence and continuing until the sequence is
    # num_steps long.
    generate_request.input_sequence.CopyFrom(primer_sequence)
    generate_section = (
        generate_request.generator_options.generate_sections.add())
    # Set the start time to begin on the next step after the last note ends.
    notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time)
    last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0
    generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
        1, bpm)
    generate_section.end_time_seconds = total_seconds

    # If the primer already fills (or exceeds) the requested length there is
    # nothing left to generate.
    if generate_section.start_time_seconds >= generate_section.end_time_seconds:
      tf.logging.fatal(
          'Priming sequence is longer than the total number of steps requested: '
          'Priming sequence length: %s, Generation length requested: %s',
          generate_section.start_time_seconds, total_seconds)
      return

  # Make the generate request num_outputs times and save the output as midi
  # files.
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
  # Zero-pad output indices so filenames sort naturally.
  digits = len(str(FLAGS.num_outputs))
  for i in range(FLAGS.num_outputs):
    generate_response = melody_rnn_sequence_generator.generate(
        generate_request)

    midi_filename = '%s_%s.mid' % (date_and_time, str(i + 1).zfill(digits))
    midi_path = os.path.join(FLAGS.output_dir, midi_filename)
    midi_io.sequence_proto_to_midi_file(
        generate_response.generated_sequence, midi_path)

  tf.logging.info('Wrote %d MIDI files to %s',
                  FLAGS.num_outputs, FLAGS.output_dir)