def get_checkpoint():
    """Get the training dir or checkpoint path to be used by the model."""
    if ((FLAGS.run_dir or FLAGS.checkpoint_file) and FLAGS.bundle_file
            and not should_save_generator_bundle()):
        raise sequence_generator.SequenceGeneratorException(
            'Cannot specify both bundle_file and one of run_dir or '
            'checkpoint_file.')
    if FLAGS.run_dir:
        train_dir = os.path.join(os.path.expanduser(FLAGS.run_dir), 'train')
        return train_dir
    elif FLAGS.checkpoint_file:
        return os.path.expanduser(FLAGS.checkpoint_file)
    else:
        return None
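
A minimal usage sketch (assumed, not part of the snippet): under TF 1.x, a train directory returned by get_checkpoint() would typically be resolved to its newest checkpoint file with tf.train.latest_checkpoint before restoring:

import os

import tensorflow as tf

checkpoint = get_checkpoint()
if checkpoint and os.path.isdir(checkpoint):
    # get_checkpoint() returned a train dir; pick its most recent checkpoint.
    checkpoint = tf.train.latest_checkpoint(checkpoint)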
Example #2
    def _generate(self, generate_sequence_request):
        generator_options = generate_sequence_request.generator_options
        if len(generator_options.generate_sections) != 1:
            raise sequence_generator.SequenceGeneratorException(
                'This model supports only 1 generate_sections message, but '
                'got %s' % len(generator_options.generate_sections))

        generate_section = generator_options.generate_sections[0]
        primer_sequence = generate_sequence_request.input_sequence

        notes_by_end_time = sorted(primer_sequence.notes,
                                   key=lambda n: n.end_time)
        last_end_time = (notes_by_end_time[-1].end_time
                         if notes_by_end_time else 0)
        if last_end_time > generate_section.start_time_seconds:
            raise sequence_generator.SequenceGeneratorException(
                'Got GenerateSection request for section that is before the '
                'end of the NoteSequence. This model can only extend '
                'sequences. Requested start time: %s, Final note end time: %s'
                % (generate_section.start_time_seconds, last_end_time))

        # Quantize the priming sequence.
        quantized_sequence = sequences_lib.QuantizedSequence()
        quantized_sequence.from_note_sequence(primer_sequence,
                                              self._steps_per_beat)
        # Setting gap_bars to infinite ensures that the entire input will be used.
        extracted_melodies, _ = melodies_lib.extract_melodies(
            quantized_sequence,
            min_bars=0,
            min_unique_pitches=1,
            gap_bars=float('inf'),
            ignore_polyphonic_notes=True)
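        # With gap_bars=inf the primer is never split, so extract_melodies
        # returns at most one melody (checked by the assert below).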
        assert len(extracted_melodies) <= 1

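        # Use the primer's tempo when one is present; otherwise fall back to
        # the library default.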
        bpm = (primer_sequence.tempos[0].bpm
               if primer_sequence and primer_sequence.tempos
               else melodies_lib.DEFAULT_BEATS_PER_MINUTE)
        start_step = self._seconds_to_steps(
            generate_section.start_time_seconds, bpm)
        end_step = self._seconds_to_steps(generate_section.end_time_seconds,
                                          bpm)

        if extracted_melodies and extracted_melodies[0]:
            melody = extracted_melodies[0]
        else:
            tf.logging.warn(
                'No melodies were extracted from the priming sequence. '
                'Melodies will be generated from scratch.')
            melody = melodies_lib.MonophonicMelody()
            melody.from_event_list([
                random.randint(self._melody_encoder_decoder.min_note,
                               self._melody_encoder_decoder.max_note)
            ])
            start_step += 1

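        # Squash the primer into the encoder's note range, remembering the
        # transposition so it can be undone on the generated output below.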
        transpose_amount = melody.squash(
            self._melody_encoder_decoder.min_note,
            self._melody_encoder_decoder.max_note,
            self._melody_encoder_decoder.transpose_to_key)

        # Ensure that the melody extends up to the step we want to start generating.
        melody.set_length(start_step)

        inputs = self._session.graph.get_collection('inputs')[0]
        initial_state = self._session.graph.get_collection('initial_state')[0]
        final_state = self._session.graph.get_collection('final_state')[0]
        softmax = self._session.graph.get_collection('softmax')[0]

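        # Autoregressive sampling loop: the first iteration feeds the full
        # primer to warm up the RNN state; later iterations feed only the
        # newest event plus the previous final state, growing the melody one
        # event per step.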
        final_state_ = None
        for i in range(end_step - len(melody)):
            if i == 0:
                inputs_ = self._melody_encoder_decoder.get_inputs_batch(
                    [melody], full_length=True)
                initial_state_ = self._session.run(initial_state)
            else:
                inputs_ = self._melody_encoder_decoder.get_inputs_batch(
                    [melody])
                initial_state_ = final_state_

            feed_dict = {inputs: inputs_, initial_state: initial_state_}
            final_state_, softmax_ = self._session.run([final_state, softmax],
                                                       feed_dict)
            self._melody_encoder_decoder.extend_melodies([melody], softmax_)

        melody.transpose(-transpose_amount)

        generate_response = generator_pb2.GenerateSequenceResponse()
        generate_response.generated_sequence.CopyFrom(
            melody.to_sequence(bpm=bpm))
        return generate_response
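
For context, a hedged sketch of a caller, assuming an instance of this generator class named generator and a primer NoteSequence named primer (both names are hypothetical; the proto fields match those used above):

request = generator_pb2.GenerateSequenceRequest()
request.input_sequence.CopyFrom(primer)
section = request.generator_options.generate_sections.add()
section.start_time_seconds = 10.0  # must not precede the primer's last note
section.end_time_seconds = 30.0    # generate events up to this time
response = generator._generate(request)
generated = response.generated_sequence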