def get_checkpoint():
  """Resolve where the model should load its trained weights from.

  Reads the `run_dir`, `checkpoint_file` and `bundle_file` flags.

  Returns:
    `<run_dir>/train` (with `~` expanded in run_dir) when `run_dir` is set;
    otherwise the `~`-expanded `checkpoint_file` when that is set;
    otherwise None.

  Raises:
    sequence_generator.SequenceGeneratorException: if `bundle_file` is given
      together with `run_dir` or `checkpoint_file`, unless
      `should_save_generator_bundle()` reports that a bundle is being saved.
  """
  checkpoint_source_given = FLAGS.run_dir or FLAGS.checkpoint_file
  if (checkpoint_source_given and FLAGS.bundle_file and
      not should_save_generator_bundle()):
    raise sequence_generator.SequenceGeneratorException(
        'Cannot specify both bundle_file and run_dir or checkpoint_file')

  if FLAGS.run_dir:
    # Training output lives in the 'train' subdirectory of the run dir.
    return os.path.join(os.path.expanduser(FLAGS.run_dir), 'train')
  if FLAGS.checkpoint_file:
    return os.path.expanduser(FLAGS.checkpoint_file)
  return None
def _generate(self, generate_sequence_request):
  """Generate a melody that continues the request's primer sequence.

  The primer (`generate_sequence_request.input_sequence`) is quantized and
  its melody is extracted. If no melody can be extracted, a one-note random
  seed is used instead. The graph in `self._session` is then run one step at
  a time until the melody reaches the end of the requested section.

  Args:
    generate_sequence_request: A request whose `generator_options` must hold
      exactly one `generate_sections` entry and whose `input_sequence` is the
      primer NoteSequence to extend.

  Returns:
    A `generator_pb2.GenerateSequenceResponse` whose `generated_sequence` is
    built from the (primer + generated) melody via `melody.to_sequence`.

  Raises:
    sequence_generator.SequenceGeneratorException: If the request does not
      contain exactly one generate section, or if that section starts before
      the last note of the primer ends.
  """
  generator_options = generate_sequence_request.generator_options
  if len(generator_options.generate_sections) != 1:
    raise sequence_generator.SequenceGeneratorException(
        'This model supports only 1 generate_sections message, but got %s' %
        (len(generator_options.generate_sections)))
  generate_section = generator_options.generate_sections[0]

  primer_sequence = generate_sequence_request.input_sequence
  notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time)
  last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0
  if last_end_time > generate_section.start_time_seconds:
    # Report last_end_time rather than re-indexing notes_by_end_time: with an
    # empty primer (last_end_time == 0) and a negative start time, indexing
    # the empty list would raise IndexError instead of this exception.
    raise sequence_generator.SequenceGeneratorException(
        'Got GenerateSection request for section that is before the end of '
        'the NoteSequence. This model can only extend sequences. '
        'Requested start time: %s, Final note end time: %s' %
        (generate_section.start_time_seconds, last_end_time))

  # Quantize the priming sequence.
  quantized_sequence = sequences_lib.QuantizedSequence()
  quantized_sequence.from_note_sequence(primer_sequence, self._steps_per_beat)

  # Setting gap_bars to infinite ensures that the entire input will be used;
  # the assert below checks that extraction yields at most one melody.
  extracted_melodies, _ = melodies_lib.extract_melodies(
      quantized_sequence, min_bars=0, min_unique_pitches=1,
      gap_bars=float('inf'), ignore_polyphonic_notes=True)
  assert len(extracted_melodies) <= 1

  # Use the primer's first tempo when present, else the library default.
  bpm = (primer_sequence.tempos[0].bpm
         if primer_sequence and primer_sequence.tempos
         else melodies_lib.DEFAULT_BEATS_PER_MINUTE)
  start_step = self._seconds_to_steps(
      generate_section.start_time_seconds, bpm)
  end_step = self._seconds_to_steps(generate_section.end_time_seconds, bpm)

  if extracted_melodies and extracted_melodies[0]:
    melody = extracted_melodies[0]
  else:
    tf.logging.warn('No melodies were extracted from the priming sequence. '
                    'Melodies will be generated from scratch.')
    # Seed with a single random pitch in the inclusive range
    # [min_note, max_note] of the encoder/decoder.
    melody = melodies_lib.MonophonicMelody()
    melody.from_event_list([
        random.randint(self._melody_encoder_decoder.min_note,
                       self._melody_encoder_decoder.max_note)
    ])
    # NOTE(review): presumably shifts generation one step later to make room
    # for the seed event — confirm intent.
    start_step += 1

  # Squash the melody into the encoder/decoder's note range; the returned
  # transposition is undone after generation (melody.transpose below).
  transpose_amount = melody.squash(
      self._melody_encoder_decoder.min_note,
      self._melody_encoder_decoder.max_note,
      self._melody_encoder_decoder.transpose_to_key)

  # Ensure that the melody extends up to the step we want to start generating.
  melody.set_length(start_step)

  inputs = self._session.graph.get_collection('inputs')[0]
  initial_state = self._session.graph.get_collection('initial_state')[0]
  final_state = self._session.graph.get_collection('final_state')[0]
  softmax = self._session.graph.get_collection('softmax')[0]

  # One model step per missing step; each extend_melodies call appends to
  # `melody`, and the iteration count is fixed before the loop starts.
  final_state_ = None
  for i in range(end_step - len(melody)):
    if i == 0:
      # First step: encode the full melody and start from the graph's
      # initial state.
      inputs_ = self._melody_encoder_decoder.get_inputs_batch(
          [melody], full_length=True)
      initial_state_ = self._session.run(initial_state)
    else:
      # Later steps: default (non-full-length) inputs, presumably just the
      # newest event — confirm; state is carried over from the previous run.
      inputs_ = self._melody_encoder_decoder.get_inputs_batch([melody])
      initial_state_ = final_state_

    feed_dict = {inputs: inputs_, initial_state: initial_state_}
    final_state_, softmax_ = self._session.run([final_state, softmax],
                                               feed_dict)
    self._melody_encoder_decoder.extend_melodies([melody], softmax_)

  # Undo the transposition applied by squash().
  melody.transpose(-transpose_amount)

  generate_response = generator_pb2.GenerateSequenceResponse()
  generate_response.generated_sequence.CopyFrom(
      melody.to_sequence(bpm=bpm))
  return generate_response