def testIsTraining(self):
    """Checks tensor-count limiting while the converter is in training mode.

    With max_tensors_per_notesequence=2 the converter must yield 2 tensors;
    after the limit is removed it must yield all 5 tensors for self.sequence.
    """
    converter = data.OneHotDrumsConverter(
        steps_per_quarter=1, slice_bars=1, max_tensors_per_notesequence=2)
    # Training mode has to be enabled on the converter under test. Assigning
    # to `self.is_training` would only set an attribute on the TestCase and
    # leave the converter in its default mode.
    # NOTE(review): assumes the converter exposes a settable `is_training`
    # attribute -- confirm against the converter base class.
    converter.is_training = True
    self.assertEqual(2, len(converter.to_tensors(self.sequence)[0]))

    # With no limit, every tensor produced from the sequence is returned.
    converter.max_tensors_per_notesequence = None
    self.assertEqual(5, len(converter.to_tensors(self.sequence)[0]))
def testMaxOutputsPerNoteSequence(self):
    """Verifies that max_tensors_per_notesequence caps the number of tensors.

    The limit is changed in place on a single converter. Once it exceeds the
    number of tensors available for self.sequence (5), all of them come back.
    """
    converter = data.OneHotDrumsConverter(
        steps_per_quarter=1, slice_bars=1, max_tensors_per_notesequence=2)

    def count_tensors():
        # Length of the first element of the to_tensors() result.
        first_output = converter.to_tensors(self.sequence)[0]
        return len(first_output)

    self.assertEqual(2, count_tensors())

    # (new limit, expected count) pairs, checked in order on the same
    # converter; a limit of 100 is above the 5 tensors available.
    for limit, expected_count in ((3, 3), (100, 5)):
        converter.max_tensors_per_notesequence = limit
        self.assertEqual(expected_count, count_tensors())
def testToNoteSequence(self):
    """Round-trips instrument 1 of self.sequence through tensors and back.

    A converter with 2-bar slices and at most one tensor per sequence encodes
    the filtered sequence. The decoded result must be a single NoteSequence
    equal to the hand-built expected drum track.
    """
    converter = data.OneHotDrumsConverter(
        steps_per_quarter=1, slice_bars=2, max_tensors_per_notesequence=1)
    instrument_one_only = filter_instrument(self.sequence, 1)
    _, output_tensors = converter.to_tensors(instrument_one_only)
    decoded_sequences = converter.to_notesequences(output_tensors)
    self.assertEqual(1, len(decoded_sequences))

    # Note tuples for testing_lib.add_track_to_sequence -- presumably
    # (pitch, velocity, start, end); confirm against testing_lib.
    expected_notes = [
        (38, 80, 0.5, 1.0),
        (48, 80, 2.0, 2.5),
        (49, 80, 2.0, 2.5),
        (51, 80, 3.5, 4.0),
    ]
    expected_sequence = music_pb2.NoteSequence(ticks_per_quarter=220)
    expected_sequence.tempos.add(qpm=120)
    testing_lib.add_track_to_sequence(
        expected_sequence, 9, expected_notes, is_drum=True)
    self.assertProtoEquals(expected_sequence, decoded_sequences[0])
# 'cat-drums_2bar_small': MusicVAE with a BidirectionalLstmEncoder and a
# CategoricalLstmDecoder. Starts from lstm_models.get_default_hparams() and
# overrides batch_size=512, max_seq_len=32 (2 bars w/ 16 steps per bar),
# z_size=256, enc_rnn_size=[512] and dec_rnn_size=[256, 256]. The
# OneHotDrumsConverter takes 2-bar slices at steps_per_quarter=4 (with
# binary_input=True) after truncating long drum sequences to max_bars=100.
# No note-sequence augmenter is used. Example paths are None here --
# presumably supplied at run time; confirm with the training entry point.
# NOTE(review): the 'cat-drums_2bar_big' entry also begins on the next line
# and continues beyond the visible excerpt.
config_map['cat-drums_2bar_small'] = Config( model=MusicVAE(lstm_models.BidirectionalLstmEncoder(), lstm_models.CategoricalLstmDecoder()), hparams=merge_hparams( lstm_models.get_default_hparams(), HParams( batch_size=512, max_seq_len=32, # 2 bars w/ 16 steps per bar z_size=256, enc_rnn_size=[512], dec_rnn_size=[256, 256], )), note_sequence_augmenter=None, note_sequence_converter=data.OneHotDrumsConverter( max_bars=100, # Truncate long drum sequences before slicing. slice_bars=2, steps_per_quarter=4, binary_input=True), train_examples_path=None, eval_examples_path=None, ) config_map['cat-drums_2bar_big'] = Config( model=MusicVAE(lstm_models.BidirectionalLstmEncoder(), lstm_models.CategoricalLstmDecoder()), hparams=merge_hparams( lstm_models.get_default_hparams(), HParams( batch_size=512, max_seq_len=32, # 2 bars w/ 16 steps per bar z_size=512,