def __init__(self, checkpoint=None, bundle=None):
  """Builds the test generator around a fresh `Model()`.

  The generator always identifies itself with id 'test_generator' and
  description 'Test Generator'.

  Args:
    checkpoint: Forwarded unchanged to the base-class constructor.
    bundle: Forwarded unchanged to the base-class constructor.
  """
  test_details = generator_pb2.GeneratorDetails(
      id='test_generator',
      description='Test Generator')
  base_init = super(SeuenceGenerator, self).__init__
  base_init(Model(), test_details, checkpoint=checkpoint, bundle=bundle)
def testUseMatchingGeneratorId(self):
  """A bundle is accepted only when its generator id matches the generator's."""
  matching_details = generator_pb2.GeneratorDetails(id='test_generator')
  bundle = generator_pb2.GeneratorBundle(
      generator_details=matching_details,
      checkpoint_file=[b'foo.ckpt'],
      metagraph_file=b'foo.ckpt.meta')

  # Matching id ('test_generator'): construction must succeed.
  SeuenceGenerator(bundle=bundle)

  # Mutate the bundle's own copy of the details so the id no longer matches;
  # construction must now be rejected.
  bundle.generator_details.id = 'blarg'
  self.assertRaises(sequence_generator.SequenceGeneratorError,
                    SeuenceGenerator, bundle=bundle)
def testSpecifyEitherCheckPointOrBundle(self):
  """Exactly one of `checkpoint` or `bundle` must be supplied."""
  generator_details = generator_pb2.GeneratorDetails(id='test_generator')
  bundle = generator_pb2.GeneratorBundle(
      generator_details=generator_details,
      checkpoint_file=[b'foo.ckpt'],
      metagraph_file=b'foo.ckpt.meta')
  expected_error = sequence_generator.SequenceGeneratorError

  # Supplying both sources is rejected.
  self.assertRaises(expected_error, SeuenceGenerator,
                    checkpoint='foo.ckpt', bundle=bundle)
  # Supplying neither source is rejected.
  self.assertRaises(expected_error, SeuenceGenerator,
                    checkpoint=None, bundle=None)

  # Supplying exactly one source is accepted.
  SeuenceGenerator(checkpoint='foo.ckpt')
  SeuenceGenerator(bundle=bundle)
def testGetBundleDetails(self):
  """`bundle_details` is None without a bundle and matches the bundle's otherwise."""
  # Non-bundle generator: there are no bundle details to report. Use an
  # identity check against the None singleton rather than equality.
  seq_gen = SeuenceGenerator(checkpoint='foo.ckpt')
  self.assertIsNone(seq_gen.bundle_details)

  # Bundle-based generator: the details stored in the bundle are exposed.
  bundle_details = generator_pb2.GeneratorBundle.BundleDetails(
      description='bundle of joy')
  bundle = generator_pb2.GeneratorBundle(
      generator_details=generator_pb2.GeneratorDetails(id='test_generator'),
      bundle_details=bundle_details,
      checkpoint_file=[b'foo.ckpt'],
      metagraph_file=b'foo.ckpt.meta')
  seq_gen = SeuenceGenerator(bundle=bundle)
  self.assertEqual(bundle_details, seq_gen.bundle_details)
50, 30, 62, 76, 83], # high tom # closed hi-hat [42, 44, 54, 68, 69, 70, 71, 73, 78, 80, 22], # open hi-hat [46, 67, 72, 74, 79, 81, 26, 49, 52, 55, 57, 58, # crash 51, 53, 59, 82], # ride ] # Default configurations. default_configs = { 'one_drum': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails( id='one_drum', description='Drums RNN with 2-state encoding.'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.MultiDrumOneHotEncoding( [[39] + # use hand clap as default when decoding list(range(note_seq.MIN_MIDI_PITCH, 39)) + list(range(39, note_seq.MAX_MIDI_PITCH + 1))])), contrib_training.HParams( batch_size=128, rnn_layer_sizes=[128, 128], dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001), steps_per_quarter=2), 'drum_kit': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails(
control_encoder = note_seq.OptionalEventSequenceEncoder(control_encoder) encoder_decoder = note_seq.ConditionalEventSequenceEncoderDecoder( control_encoder, encoder_decoder) super(PerformanceRnnConfig, self).__init__( details, encoder_decoder, hparams) self.num_velocity_bins = num_velocity_bins self.control_signals = control_signals self.optional_conditioning = optional_conditioning self.note_performance = note_performance default_configs = { 'performance': PerformanceRnnConfig( generator_pb2.GeneratorDetails( id='performance', description='Performance RNN'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.PerformanceOneHotEncoding()), contrib_training.HParams( batch_size=64, rnn_layer_sizes=[512, 512, 512], dropout_keep_prob=1.0, clip_norm=3, learning_rate=0.001)), 'performance_with_dynamics': PerformanceRnnConfig( generator_pb2.GeneratorDetails( id='performance_with_dynamics', description='Performance RNN with dynamics'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.PerformanceOneHotEncoding(num_velocity_bins=32)),
branch_factor, steps_per_iteration, modify_events_callback=modify_events_callback) def polyphonic_sequence_log_likelihood(self, sequence): """Evaluate the log likelihood of a polyphonic sequence. Args: sequence: The PolyphonicSequence object for which to evaluate the log likelihood. Returns: The log likelihood of `sequence` under this model. """ return self._evaluate_log_likelihood([sequence])[0] default_configs = { 'polyphony': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails(id='polyphony', description='Polyphonic RNN'), note_seq.OneHotEventSequenceEncoderDecoder( polyphony_encoder_decoder.PolyphonyOneHotEncoding()), contrib_training.HParams(batch_size=64, rnn_layer_sizes=[256, 256, 256], dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001)), }
"""Evaluate the log likelihood of a drum track under the model. Args: drums: The DrumTrack object for which to evaluate the log likelihood. Returns: The log likelihood of `drums` under this model. """ return self._evaluate_log_likelihood([drums])[0] # Default configurations. default_configs = { 'one_drum': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails( id='one_drum', description='Drums RNN with 2-state encoding.'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.MultiDrumOneHotEncoding( [[39] + # use hand clap as default when decoding list(range(note_seq.MIN_MIDI_PITCH, 39)) + list(range(39, note_seq.MAX_MIDI_PITCH + 1))])), contrib_training.HParams(batch_size=128, rnn_layer_sizes=[128, 128], dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001), steps_per_quarter=2), 'drum_kit': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails( id='drum_kit',
Returns: The generated PianorollSequence object (which begins with the provided primer track). """ return self._generate_events(num_steps=num_steps, primer_events=primer_sequence, temperature=None, beam_size=beam_size, branch_factor=branch_factor, steps_per_iteration=steps_per_iteration) default_configs = { 'rnn-nade': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails(id='rnn-nade', description='RNN-NADE'), note_seq.PianorollEncoderDecoder(), contrib_training.HParams(batch_size=64, rnn_layer_sizes=[128, 128, 128], nade_hidden_units=128, dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001)), 'rnn-nade_attn': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails(id='rnn-nade_attn', description='RNN-NADE with attention.'), note_seq.PianorollEncoderDecoder(), contrib_training.HParams(batch_size=48, rnn_layer_sizes=[128, 128], attn_length=32,
and (transpose_to_key < 0 or transpose_to_key > note_seq.NOTES_PER_OCTAVE - 1)): raise ValueError('transpose_to_key must be >= 0 and <= 11. ' 'transpose_to_key is %d.' % transpose_to_key) self.min_note = min_note self.max_note = max_note self.transpose_to_key = transpose_to_key # Default configurations. default_configs = { 'basic_improv': ImprovRnnConfig( generator_pb2.GeneratorDetails( id='basic_improv', description='Basic melody-given-chords RNN with one-hot triad ' 'encoding for chords.'), note_seq.ConditionalEventSequenceEncoderDecoder( note_seq.OneHotEventSequenceEncoderDecoder( note_seq.TriadChordOneHotEncoding()), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.MelodyOneHotEncoding(min_note=DEFAULT_MIN_NOTE, max_note=DEFAULT_MAX_NOTE))), contrib_training.HParams(batch_size=128, rnn_layer_sizes=[64, 64], dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001)), 'attention_improv': ImprovRnnConfig( generator_pb2.GeneratorDetails(
if (transpose_to_key is not None and (transpose_to_key < 0 or transpose_to_key > note_seq.NOTES_PER_OCTAVE - 1)): raise ValueError('transpose_to_key must be >= 0 and <= 11. ' 'transpose_to_key is %d.' % transpose_to_key) self.min_note = min_note self.max_note = max_note self.transpose_to_key = transpose_to_key # Default configurations. default_configs = { 'basic_rnn': MelodyRnnConfig( generator_pb2.GeneratorDetails( id='basic_rnn', description='Melody RNN with one-hot encoding.'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.MelodyOneHotEncoding(min_note=DEFAULT_MIN_NOTE, max_note=DEFAULT_MAX_NOTE)), contrib_training.HParams(batch_size=128, rnn_layer_sizes=[128, 128], dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001)), 'mono_rnn': MelodyRnnConfig(generator_pb2.GeneratorDetails( id='mono_rnn', description='Monophonic RNN with one-hot encoding.'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.MelodyOneHotEncoding(min_note=0, max_note=128)), contrib_training.HParams(batch_size=128,