def train_chord_to_melody_model(self, tt_split=0.9, epochs=100, model_name='basic_rnn'):
    '''
    Train model step - model takes in chord piano roll and outputs melody piano roll.

    :param tt_split: train test split ratio (fraction of data used for training)
    :param epochs: number of epochs to train
    :param model_name: specify which model we are training
    :return: None. Model is assigned as self.model for this generator
    '''
    # Train test split (populates self.X_train / self.Y_train / self.X_test / self.Y_test)
    self.__prepare_data_tt_splited(tt_split=tt_split, model_name=model_name,
                                   src="nottingham-embed")

    # Load / train model
    if model_name == 'basic_rnn':
        mb = ModelBuilder(self.X_train, self.Y_train, self.X_test, self.Y_test)
        if os.path.exists("basic_rnn.h5"):
            # Resume from previously saved weights - architecture must match
            # the one the weights were saved from.
            model = mb.build_basic_rnn_model(input_dim=self.X_train.shape[1:])
            model.load_weights("basic_rnn.h5")
        else:
            # BUG FIX: the original built an attention bidirectional model here
            # but saved its weights as "basic_rnn.h5", which the load branch
            # above would then fail to load into a basic RNN. Build the basic
            # RNN so the saved checkpoint matches the load path.
            model = mb.build_basic_rnn_model(input_dim=self.X_train.shape[1:])
            model = mb.train_model(model, epochs, loss="categorical_crossentropy")
            model.save_weights("basic_rnn.h5")
        # Assigned inside the branch: `model` only exists for recognized names.
        self.model = model
def load_model(self, model_name, tt_split=0.9, is_fast_load=True):
    '''
    Build the requested model architecture and load its pretrained weights.

    :param model_name: which pretrained model to load (see table below)
    :param tt_split: train test split ratio, only used when is_fast_load is False
    :param is_fast_load: skip dataset preparation and build the model with a
                         data-less ModelBuilder when True
    :return: None. Model is assigned as self.model; self.model_name is updated
             only when the name is recognized.
    '''
    # Clear session to avoid stale-graph errors when (re)building models.
    K.clear_session()
    print("Chosen model: {}".format(model_name))

    if not is_fast_load:
        # Embedding-based models use the embedded Nottingham dataset; the
        # piano-roll models use the raw one.
        if model_name in ('bidem', 'attention', 'bidem_preload'):
            src = 'nottingham-embed'
        else:
            src = 'nottingham'
        # Train test split (populates self.X_train / self.Y_train / ...)
        self.__prepare_data_tt_splited(tt_split=tt_split,
                                       model_name=model_name, src=src)
        print('Chords shape: {} Melodies shape: {}'.format(
            self.X_train.shape, self.Y_train.shape))

    if is_fast_load:
        mb = ModelBuilder(None, None, None, None)
    else:
        mb = ModelBuilder(self.X_train, self.Y_train, self.X_test, self.Y_test)

    # model_name -> (builder method, input_dim, pretrained weights path).
    # Replaces the original six near-identical elif arms.
    model_configs = {
        'basic_rnn_normalized': (
            mb.build_basic_rnn_model, (1200, 128),
            '../note/active_models/basic_rnn_weights_500.h5'),
        'basic_rnn_unnormalized': (
            mb.build_basic_rnn_model, (1200, 128),
            '../note/active_models/basic_rnn_weights_500_unnormalized.h5'),
        'bidem': (
            mb.build_bidirectional_rnn_model, (1200, ),
            '../note/active_models/bidem_weights_500.h5'),
        'bidem_regularized': (
            mb.build_bidirectional_rnn_model_no_embeddings, (1200, 1),
            '../note/active_models/bidirectional_regularized_500.h5'),
        'attention': (
            mb.build_attention_bidirectional_rnn_model, (1200, ),
            '../note/active_models/attention_weights_1000.h5'),
        'bidem_preload': (
            mb.build_bidirectional_rnn_model_no_embeddings, (None, 32),
            '../note/active_models/bidirectional_embedding_preload_100.h5'),
    }

    if model_name not in model_configs:
        # Unknown name: report and leave self.model / self.model_name untouched.
        print('No model name: {}'.format(model_name))
        return

    build_fn, input_dim, weights_path = model_configs[model_name]
    self.model = build_fn(input_dim=input_dim)
    print('Loading ' + weights_path + '...')
    self.model.load_weights(weights_path)
    self.model_name = model_name