예제 #1
0
    def train_chord_to_melody_model(self,
                                    tt_split=0.9,
                                    epochs=100,
                                    model_name='basic_rnn',
                                    weights_path='basic_rnn.h5'):
        '''
        Train model step - model takes in chord piano roll and outputs melody piano roll.

        If ``weights_path`` already exists, the model is rebuilt and those
        weights are loaded instead of retraining. Otherwise the model is
        trained for ``epochs`` epochs and its weights are written to
        ``weights_path`` so the next call can reuse them.

        :param tt_split: train test split
        :param epochs: number of epochs to train (ignored when weights are loaded)
        :param model_name: specify which model we are training; only 'basic_rnn'
            is supported
        :param weights_path: file used to cache the trained weights
        :raises ValueError: if model_name is not supported
        :return: None. Model is assigned as self.model for this generator
        '''
        if model_name != 'basic_rnn':
            # Previously an unsupported name fell through to `self.model = model`
            # and crashed with UnboundLocalError; fail fast with a clear message.
            raise ValueError('Unsupported model name: {}'.format(model_name))

        # Train test split (expected to set self.X_train / Y_train / X_test / Y_test,
        # which are read below)
        self.__prepare_data_tt_splited(tt_split=tt_split,
                                       model_name=model_name,
                                       src="nottingham-embed")

        mb = ModelBuilder(self.X_train, self.Y_train, self.X_test, self.Y_test)

        # The SAME architecture must be built whether we train or load: the
        # weights file is produced by the training path below, so loading it
        # into a different architecture (as the old load path did with
        # build_basic_rnn_model) cannot work.
        # NOTE(review): 'basic_rnn' has always trained the attention
        # bidirectional model on the embedded data; confirm this is intended.
        model = mb.build_attention_bidirectional_rnn_model(
            input_dim=self.X_train.shape[1:])

        if os.path.exists(weights_path):
            model.load_weights(weights_path)
        else:
            model = mb.train_model(model,
                                   epochs,
                                   loss="categorical_crossentropy")
            model.save_weights(weights_path)

        self.model = model
예제 #2
0
    def load_model(self, model_name, tt_split=0.9, is_fast_load=True,
                   weights_dir='../note/active_models'):
        '''
        Build the architecture for ``model_name`` and load its pretrained weights.

        :param model_name: one of the keys of ``model_specs`` below
        :param tt_split: train test split; only used when is_fast_load is False
        :param is_fast_load: if True, skip data preparation and build the model
            from a ModelBuilder with no data
        :param weights_dir: directory containing the ``.h5`` weight files
        :return: None. On success self.model and self.model_name are set. An
            unknown model_name prints a message and returns without touching
            the backend session, the data, or self.model / self.model_name.
        '''
        # model name -> (ModelBuilder method, input_dim, weights file name)
        model_specs = {
            'basic_rnn_normalized':
                ('build_basic_rnn_model', (1200, 128),
                 'basic_rnn_weights_500.h5'),
            'basic_rnn_unnormalized':
                ('build_basic_rnn_model', (1200, 128),
                 'basic_rnn_weights_500_unnormalized.h5'),
            'bidem':
                ('build_bidirectional_rnn_model', (1200, ),
                 'bidem_weights_500.h5'),
            'bidem_regularized':
                ('build_bidirectional_rnn_model_no_embeddings', (1200, 1),
                 'bidirectional_regularized_500.h5'),
            'attention':
                ('build_attention_bidirectional_rnn_model', (1200, ),
                 'attention_weights_1000.h5'),
            'bidem_preload':
                ('build_bidirectional_rnn_model_no_embeddings', (None, 32),
                 'bidirectional_embedding_preload_100.h5'),
        }
        # Models trained on the embedded variant of the dataset.
        embedded_models = ('bidem', 'attention', 'bidem_preload')

        print("Chosen model: {}".format(model_name))

        spec = model_specs.get(model_name)
        if spec is None:
            # Validate before K.clear_session() and data preparation so an
            # unknown name does not reset backend state or do wasted work.
            print('No model name: {}'.format(model_name))
            return
        builder_name, input_dim, weights_file = spec

        # clear session to avoid any errors
        K.clear_session()

        if is_fast_load:
            mb = ModelBuilder(None, None, None, None)
        else:
            src = 'nottingham-embed' if model_name in embedded_models else 'nottingham'
            # Train test split (expected to set self.X_train / Y_train / X_test / Y_test)
            self.__prepare_data_tt_splited(tt_split=tt_split,
                                           model_name=model_name,
                                           src=src)
            print('Chords shape: {}  Melodies shape: {}'.format(
                self.X_train.shape, self.Y_train.shape))
            mb = ModelBuilder(self.X_train, self.Y_train, self.X_test,
                              self.Y_test)

        self.model = getattr(mb, builder_name)(input_dim=input_dim)
        weights_path = os.path.join(weights_dir, weights_file)
        print('Loading ' + weights_path + '...')
        self.model.load_weights(weights_path)

        self.model_name = model_name