import os
import random
from optparse import OptionParser

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_mldata
from sklearn.neural_network import MLPClassifier


def main():
    ########################################
    ## CHANGE FILE PATH HERE TO YOUR REPO ##
    ########################################
    path = '/Users/Kevin/Desktop/'

    ########################################
    ##### THIS CODE COMPILES THE BUILD #####
    ########################################
    # Rebuild and reinstall the modified scikit-learn before running the experiment.
    os.system(''.join(['cd ' + path + 'scikit-learn/;',
                       'python setup.py build;',
                       'sudo python setup.py install;',
                       'cd ' + path + 'scikit-learn/sklearn/neural_network/']))

    #######################################
    ############ OPTION PARSER ############
    #######################################
    parser = OptionParser(usage="usage: %prog [options] arg1 arg2",
                          version="%prog 1.0")
    parser.add_option("-n", "--hidden", dest="layers", default="[10,10]",
                      help="specifies the hidden layer sizes, e.g. -n [10,10]")
    parser.add_option("-f", "--filename", dest="update_file", default="param_updates.txt",
                      help="file name to write parameter updates to")
    parser.add_option("-l", "--layernum", dest="layer_num", default="0",
                      help="specifies the layer from which the tracked update weight "
                           "is randomly sampled; for the input layer, use 0")
    parser.add_option("-t", "--trainsize", dest="training_size", default="30000",
                      help="specifies the training size")
    parser.add_option("-s", "--testsize", dest="test_size", default="5000",
                      help="specifies the test size")
    parser.add_option("-d", "--threshold", dest="threshold", default="0",
                      help="specifies the dropout threshold for updates to weights, e.g. -d 1e-5")
    parser.add_option("-p", "--percentdropout", dest="dropout_percentage", default="15",
                      help="specifies the dropout chance for updates to weights, e.g. -p 15")
    (options, args) = parser.parse_args()
    print options
    print args

    #######################################
    ########### HYPERPARAMETERS ###########
    #######################################
    training_size = int(options.training_size)
    test_size = int(options.test_size)
    hidden_layers = tuple(eval(options.layers))  # e.g. "[10,10]" -> (10, 10)
    max_iteration = 100
    tolerance = 1e-4
    batch_size = 1

    #######################################
    ########## FETCH MNIST DATA ###########
    #######################################
    mnist = fetch_mldata("MNIST original")
    # Rescale the data and use the traditional train/test split.
    X, y = mnist.data / 255., mnist.target
    X_train, X_test = X[:training_size], X[training_size:training_size + test_size]
    y_train, y_test = y[:training_size], y[training_size:training_size + test_size]
    # Validate the shapes of the training and test matrices.
    print X_train.shape, y_train.shape, X_test.shape, y_test.shape

    #######################################
    ####### RANDOM WEIGHT SELECTION #######
    #######################################
    layer_num = int(options.layer_num)
    if layer_num > len(hidden_layers):
        layer_num = 0
    if layer_num == -1:
        print "Last layer used (layer between last hidden layer and output layer)"
        layer_num = len(hidden_layers)
    layer_sizes_array = [X_train.shape[1]] + list(hidden_layers) + [len(set(y_train))]
    print layer_sizes_array
    # (layer index, random source unit, random destination unit) of the tracked weight.
    layer_num_array = [layer_num,
                       random.randint(0, layer_sizes_array[layer_num] - 1),
                       random.randint(0, layer_sizes_array[layer_num + 1] - 1)]

    #######################################
    ######### TRAIN AND FIT MODEL #########
    #######################################
    mlp = MLPClassifier(hidden_layer_sizes=hidden_layers, alpha=1e-4,
                        max_iter=max_iteration, algorithm='sgd', verbose=10,
                        tol=tolerance, random_state=1, batch_size=batch_size)
    mlp.out_file_name = path + options.update_file
    mlp.layer_num = layer_num_array
    mlp.threshold_update = float(options.threshold)
    mlp.dropout_percentage = int(options.dropout_percentage)

    # Remove any stale update file left over from a previous run.
    try:
        os.remove(path + options.update_file)
    except OSError:
        pass

    mlp.fit(X_train, y_train)
    print("Training set score: %f" % mlp.score(X_train, y_train))
    print("Test set score: %f" % mlp.score(X_test, y_test))

    # Read back the logged parameter updates and report the fraction that are negligible.
    with open(mlp.out_file_name, 'r') as f:
        lines = f.readlines()
    print "Lines are read"
    arr = np.asarray([float(x) for x in lines])
    numerator = len([x for x in arr if abs(x) < 1e-5])
    percent_negligible = numerator / float(len(arr))
    print "Percent negligible:", str(percent_negligible)
    plt.plot(range(len(arr)), arr)
    plt.show()


if __name__ == '__main__':
    main()
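# Example invocation, as a sketch: the script name mnist_updates.py is illustrative,
# and the out_file_name / layer_num / threshold_update / dropout_percentage attributes
# are hooks assumed to exist only in the modified scikit-learn fork built above:
#
#   python mnist_updates.py -n [100,50] -l 1 -t 30000 -s 5000 -d 1e-5 -p 15
#
# This would train a (100, 50) MLP on 30000 MNIST digits, test on 5000, track one
# randomly chosen weight between the two hidden layers, drop updates below 1e-5 with
# a 15% chance, log the updates to param_updates.txt, and plot them at the end.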