def test_mlp(X, Y):
    """Score the saved MLP classifier on a set of input pairs.

    Both members of every pair are encoded with the saved UnsupervisedDBN.
    The two encodings are joined column-wise (``axis=1``), and the result is
    passed to the saved classifier's ``score`` together with ``Y``.

    Args:
        X: array indexed as ``X[:, 0]`` / ``X[:, 1]`` for the two members of
            each pair; presumably shaped (n_samples, 2, n_features) -- confirm.
        Y: target labels, forwarded unchanged to ``score``.

    Returns:
        Whatever the loaded classifier's ``score(features, Y)`` returns.
    """
    dbn: UnsupervisedDBN = UnsupervisedDBN.load(DBN_NETWORK_FILENAME)
    classifier: MLPClassifier = joblib.load(MLP_NETWORK_FILENAME)
    # Encode column 0 first, then column 1, so the feature layout matches
    # the one produced at training time.
    encoded_pair = tuple(dbn.transform(X[:, column]) for column in (0, 1))
    return classifier.score(np.concatenate(encoded_pair, axis=1), Y)
def predict_mlp(first, second):
    """Predict the MLP label(s) for one pair, or a batch of pairs, of inputs.

    Each input is encoded with the saved UnsupervisedDBN. The two encodings
    are concatenated side by side (``axis=1``), which is the same layout
    ``train_mlp`` and ``test_mlp`` build. The saved MLPClassifier then
    predicts on the result.

    Args:
        first: first member of the pair. Either a single sample (1-D) or a
            batch with one row per sample -- presumably (n_samples,
            n_features); confirm against callers.
        second: second member of the pair, shaped like ``first``.

    Returns:
        A flat Python list of predicted labels, one per sample (a one-element
        list for a single pair).
    """
    # Promote a single 1-D sample to a one-row batch. Single samples and
    # batches can then share the axis=1 concatenation used in training. The
    # previous axis-0 concatenate + reshape(1, -1) only produced the right
    # layout for exactly one sample.
    first_batch = np.atleast_2d(np.asarray(first))
    second_batch = np.atleast_2d(np.asarray(second))
    dbn_network: UnsupervisedDBN = UnsupervisedDBN.load(DBN_NETWORK_FILENAME)
    mlp_network: MLPClassifier = joblib.load(MLP_NETWORK_FILENAME)
    first_encoded = dbn_network.transform(first_batch)
    second_encoded = dbn_network.transform(second_batch)
    features = np.concatenate((first_encoded, second_encoded), axis=1)
    result = mlp_network.predict(features)
    return result.flatten().tolist()
def train_mlp(X, Y, *, hidden_layer_sizes=(90, 45, 20, 2),
              learning_rate='constant', learning_rate_init=1e-2,
              max_iter=1000, random_state=None):
    """Train an MLP on DBN-encoded input pairs and save it to disk.

    Both members of every pair are encoded with the saved UnsupervisedDBN and
    joined column-wise (``axis=1``). An MLPClassifier is fit on the result
    and written to ``MLP_NETWORK_FILENAME`` with ``joblib.dump``.

    Args:
        X: array indexed as ``X[:, 0]`` / ``X[:, 1]`` for the two members of
            each pair; presumably shaped (n_samples, 2, n_features) -- confirm.
        Y: target labels passed to ``fit``.
        hidden_layer_sizes, learning_rate, learning_rate_init, max_iter,
        random_state: forwarded to ``MLPClassifier``. The defaults reproduce
            the previously hard-coded configuration; ``random_state=None``
            matches the classifier's own default.

    Returns:
        The fitted classifier. The same object is also saved to
        ``MLP_NETWORK_FILENAME``.
    """
    dbn_network = UnsupervisedDBN.load(DBN_NETWORK_FILENAME)
    first_positions = dbn_network.transform(X[:, 0])
    second_positions = dbn_network.transform(X[:, 1])
    features = np.concatenate((first_positions, second_positions), axis=1)
    print(f'Shape after Concatenation = {features.shape}')
    mlp_network = MLPClassifier(
        hidden_layer_sizes=hidden_layer_sizes,
        learning_rate=learning_rate,
        learning_rate_init=learning_rate_init,
        max_iter=max_iter,
        random_state=random_state,
    )
    mlp_network.fit(features, Y)
    joblib.dump(mlp_network, MLP_NETWORK_FILENAME)
    return mlp_network
def __init__(self, load=False):
    """Restore a previously saved pipeline, or build a fresh untrained one.

    Args:
        load: when true, read ``bands``, the DBN, the per-band PCA list and
            the final PCA from pickle files. The paths are relative, so they
            resolve against the current working directory. When false,
            construct new, unfitted components with the defaults below.
    """
    if load:
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only load pickles this application wrote itself.
        # Each file is opened in a `with` block so the handle is closed.
        with open("bands.pickle", "rb") as handle:
            self.bands = pickle.load(handle)
        self.dbn = UnsupervisedDBN.load("dbn.pickle")
        with open("prepca.pickle", "rb") as handle:
            self.prepca = pickle.load(handle)
        with open("pca.pickle", "rb") as handle:
            self.pca = pickle.load(handle)
    else:
        self.dbn = UnsupervisedDBN(hidden_layers_structure=[50, 50, 50],
                                   batch_size=1024,
                                   learning_rate_rbm=0.001,
                                   n_epochs_rbm=5,
                                   contrastive_divergence_iter=2,
                                   activation_function='sigmoid')
        # NOTE(review): the start/end ranges overlap (120-150, 250-300);
        # presumably intentional -- confirm.
        self.bands = [
            {"start": 0, "end": 150, "pca_components": 30, "energy_factor": 5},
            {"start": 120, "end": 300, "pca_components": 70, "energy_factor": 1},
            {"start": 250, "end": 513, "pca_components": 60, "energy_factor": 0.5},
        ]
        # One PCA per band, sized by that band's "pca_components" entry.
        self.prepca = [PCA(n_components=band["pca_components"])
                       for band in self.bands]
        self.pca = PCA(n_components=16)