def setUp(self):
    """Create the drums RNN config shared by the tests in this class.

    The config is built with no generator details (``None``), an encoder/
    decoder that wraps ``note_seq.MultiDrumOneHotEncoding`` in a
    ``note_seq.OneHotEventSequenceEncoderDecoder``, and an empty
    ``contrib_training.HParams()``. It is stored on ``self.config``.
    """
    super().setUp()
    drum_encoder_decoder = note_seq.OneHotEventSequenceEncoderDecoder(
        note_seq.MultiDrumOneHotEncoding())
    self.config = events_rnn_model.EventSequenceRnnConfig(
        None, drum_encoder_decoder, contrib_training.HParams())
def testDrumsRNNPipeline(self):
    """Check the drums RNN pipeline against a hand-chained equivalent.

    A drum track in 4/4 at 120 qpm is run through Quantizer, DrumsExtractor
    and one-hot encoding step by step. The pipeline from
    ``drums_rnn_pipeline.get_pipeline`` (with ``eval_ratio=0.0``) is expected
    to put exactly that encoded example under 'training_drum_tracks' and
    nothing under 'eval_drum_tracks'.
    """
    sequence = magenta.common.testing_lib.parse_test_proto(
        note_seq.NoteSequence,
        """ time_signatures: { numerator: 4 denominator: 4} tempos: { qpm: 120}""")

    # Presumably (pitch, velocity, start_time, end_time) tuples -- confirm
    # against note_seq.testing_lib.add_track_to_sequence.
    drum_notes = [
        (36, 100, 0.00, 2.0), (40, 55, 2.1, 5.0), (44, 80, 3.6, 5.0),
        (41, 45, 5.1, 8.0), (64, 100, 6.6, 10.0), (55, 120, 8.1, 11.0),
        (39, 110, 9.6, 9.7), (53, 99, 11.1, 14.1), (51, 40, 12.6, 13.0),
        (55, 100, 14.1, 15.0), (54, 90, 15.6, 17.0), (60, 100, 17.1, 18.0),
    ]
    note_seq.testing_lib.add_track_to_sequence(
        sequence, 0, drum_notes, is_drum=True)

    # The individual stages, configured the way the pipeline is expected to be.
    quantize_stage = note_sequence_pipelines.Quantizer(steps_per_quarter=4)
    extract_stage = drum_pipelines.DrumsExtractor(min_bars=7, gap_bars=1.0)
    encoder_decoder = note_seq.OneHotEventSequenceEncoderDecoder(
        note_seq.MultiDrumOneHotEncoding())

    # Apply them by hand, keeping only the first output of each stage.
    quantized_sequence = quantize_stage.transform(sequence)[0]
    drum_track = extract_stage.transform(quantized_sequence)[0]
    encoded = encoder_decoder.encode(drum_track)
    training_example = pipelines_common.make_sequence_example(*encoded)

    expected = {
        'training_drum_tracks': [training_example],
        'eval_drum_tracks': [],
    }

    # With eval_ratio=0.0 the test expects nothing in the eval split.
    drums_pipeline = drums_rnn_pipeline.get_pipeline(
        self.config, eval_ratio=0.0)
    actual = drums_pipeline.transform(sequence)
    self.assertEqual(expected, actual)
# open hi-hat [46, 67, 72, 74, 79, 81, 26, 49, 52, 55, 57, 58, # crash 51, 53, 59, 82], # ride ] # Default configurations. default_configs = { 'one_drum': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails( id='one_drum', description='Drums RNN with 2-state encoding.'), note_seq.OneHotEventSequenceEncoderDecoder( note_seq.MultiDrumOneHotEncoding( [[39] + # use hand clap as default when decoding list(range(note_seq.MIN_MIDI_PITCH, 39)) + list(range(39, note_seq.MAX_MIDI_PITCH + 1))])), contrib_training.HParams( batch_size=128, rnn_layer_sizes=[128, 128], dropout_keep_prob=0.5, clip_norm=5, learning_rate=0.001), steps_per_quarter=2), 'drum_kit': events_rnn_model.EventSequenceRnnConfig( generator_pb2.GeneratorDetails( id='drum_kit', description='Drums RNN with multiple drums and binary counters.' ), note_seq.LookbackEventSequenceEncoderDecoder(