コード例 #1
0
    def train_models(self,
                     training_set,
                     validation_set,
                     testing_set,
                     repeats,
                     mem_alloc=0.1,
                     training_sample_weights=None,
                     validation_sample_weights=None,
                     testing_sample_weights=None,
                     timeout_seconds=None):
        """Train ``repeats`` new Keras models and save each with its metadata.

        Each dataset argument is indexed as ``dataset[0]`` (an identifier that
        is logged and stored in the model metadata), ``dataset[1]`` (the input
        features) and ``dataset[2]`` (the regression targets).

        Args:
            training_set: Data used to fit each model.
            validation_set: Data whose ``val_loss`` drives checkpointing and
                (when ``self.patience > 0``) early stopping.
            testing_set: Optional data evaluated after training; ``None``
                skips evaluation.
            repeats: Number of models to train.
            mem_alloc: Forwarded to ``self.build_model_keras``.
            training_sample_weights: Optional per-example training weights.
                When ``self.weighted`` is set and this is ``None``, weights
                are derived from each target's distance to the mean training
                target (for both training and validation data).
            validation_sample_weights: Optional per-example validation weights.
            testing_sample_weights: Optional per-example testing weights.
            timeout_seconds: Optional time limit handed to ``TrainingMonitor``.

        Returns:
            A list with one saved ``Model`` object per trained model.
        """
        logging.info("Training %d new %s", repeats,
                     ("model" if repeats == 1 else "models"))

        logging.debug("Number of training examples:" +
                      str(len(training_set[1])))
        logging.debug("Number of validation examples:" +
                      str(len(validation_set[1])))

        if self.weighted and training_sample_weights is None:
            # Weight each sample by its distance from the mean training
            # response; validation deliberately uses the *training* mean.
            mean_training_response = np.mean(training_set[2])
            training_sample_weights = self._abs_distance_weights(
                training_set[2], mean_training_response)
            validation_sample_weights = self._abs_distance_weights(
                validation_set[2], mean_training_response)

        # Weights are optional: convert only when present instead of calling
        # ``.astype`` on None (which previously crashed unweighted training).
        train_weights = self._as_flat_float32_or_none(training_sample_weights)
        val_weights = self._as_flat_float32_or_none(validation_sample_weights)
        test_weights = self._as_flat_float32_or_none(testing_sample_weights)

        # These conversions do not depend on the repeat, so do them once.
        train_x = training_set[1].astype(np.float32)
        train_y = training_set[2].astype(np.float32).flatten()
        validation_data = (validation_set[1].astype(np.float32),
                           validation_set[2].astype(np.float32).flatten())
        if val_weights is not None:
            validation_data = validation_data + (val_weights,)

        model_indexes = [self.next_model_idx]
        if repeats > 1:
            model_indexes.extend(self.get_next_model_indexes(repeats - 1))

        models = []
        for model_index in model_indexes[:max(repeats, 0)]:
            logging.info("Training new model (index: %d).", model_index)
            model_dir = self.model_folder + "/" + str(model_index)
            ensure_exists(model_dir)

            save_file = model_dir + "/model.h5"
            metadata_file = model_dir + "/metadata.json"

            keras_model = self.build_model_keras(self.architecture,
                                                 len(training_set[1][0]),
                                                 self.regulariser,
                                                 mem_alloc=mem_alloc)
            keras_model.patience = self.patience

            opt = keras.optimizers.Adam(lr=self.learning_rate)
            keras_model.compile(optimizer=opt, loss='mse')

            checkpoint_saver = ModelCheckpoint(save_file,
                                               monitor='val_loss',
                                               verbose=1,
                                               save_best_only=True,
                                               mode='min')
            monitor = TrainingMonitor(timeout_seconds)
            callbacks = [checkpoint_saver, monitor]
            if self.patience > 0:
                callbacks.append(
                    EarlyStopping(monitor='val_loss',
                                  patience=self.patience,
                                  verbose=1,
                                  mode='min'))

            try:
                keras_model.fit(train_x,
                                train_y,
                                batch_size=self.batch_size,
                                epochs=self.max_training_epochs,
                                shuffle=True,
                                callbacks=callbacks,
                                sample_weight=train_weights,
                                validation_data=validation_data,
                                verbose=2)
            except KeyboardInterrupt:
                logging.warning(
                    "Stopping the model %s training due to keyboard interrupt",
                    save_file)

            # Restore the best checkpointed parameters (with early stopping
            # this is the state from 'patience' epochs ago).
            keras_model.load_weights(save_file)

            logging.info("Model training completed at epoch %s (patience=%s).",
                         str(monitor.epochs), str(self.patience))

            num_epochs = monitor.epochs - self.patience
            if timeout_seconds is not None and monitor.timeout_reached:
                logging.warning("Model training timed out.")
                num_epochs = monitor.epochs

            test_mse = None
            if testing_set is not None:
                test_mse = keras_model.evaluate(
                    testing_set[1].astype(np.float32),
                    testing_set[2].astype(np.float32).flatten(),
                    sample_weight=test_weights,
                    verbose=0)

                logging.info(
                    "Model errors on testing set %s had MSE %f and RMSE %f",
                    str(testing_set[0]), test_mse, math.sqrt(test_mse))

            # Create model and save its metadata file
            saved_model = Model(save_file, metadata_file, training_set[0],
                                validation_set[0], self.architecture,
                                self.learning_rate, self.patience,
                                self.batch_size, num_epochs,
                                self.max_training_epochs, self.regularizer_str,
                                self.weighted, monitor.timeout_reached)

            if testing_set is not None:
                saved_model.add_error(testing_set[0], test_mse)

            saved_model.save()

            models.append(saved_model)
            logging.info("Saved completed model to %s", metadata_file)

            backend.clear_session()

        return models

    @staticmethod
    def _abs_distance_weights(responses, centre):
        """Return ``|response - centre|`` per example as a flat float32 array."""
        return np.abs(np.asarray(responses) - centre).astype(
            'float32').flatten()

    @staticmethod
    def _as_flat_float32_or_none(values):
        """Return ``values`` as a flat float32 array, or None if it is None."""
        if values is None:
            return None
        return np.asarray(values).astype(np.float32).flatten()