def main(): pl.init_cl(1) rng = np.random.RandomState(1234) mlp = MLP(rng, 784, 32, 10) tvX, tvY, testX, testY = get_mnist_data() tvX.shape = (60000, 784) testX.shape = (10000, 784) tvX = (tvX.astype(np.float32) / 255.0).astype(np.float32) testX = (testX / 255.0).astype(np.float32) tvYoh = pd.get_dummies(tvY).values.astype(np.float32) testYoh = pd.get_dummies(testY).values.astype(np.float32) vsplit = int(0.9 * tvX.shape[0]) trainX = tvX[:vsplit, :] validX = tvX[vsplit:, :] trainY = tvYoh[:vsplit, :] validY = tvY[vsplit:].astype(np.float32) n_epochs = 30 # 000 batch_size = 512 n_train_batches = trainY.shape[0] / batch_size epoch = 0 while epoch < n_epochs: epoch += 1 for minibatch_index in xrange(n_train_batches): mlp.train( trainX[minibatch_index * batch_size:(minibatch_index + 1) * batch_size, :], trainY[minibatch_index * batch_size:(minibatch_index + 1) * batch_size, :]) if (minibatch_index == n_train_batches - 1) and minibatch_index > 0: err = mlp.test(validX, validY) print '>>>>>>>>validation error:', err, 'batch:', minibatch_index, '/', n_train_batches err = mlp.test(testX, testY) print 'TEST ERROR:', err for param, value in mlp.params: val = value.get() print '>PARAM', param, val.shape print val print '-' * 79
def setUpClass(cls):
    """Class-level fixture hook: call ``plat.init_cl(1)`` once for the class.

    NOTE(review): presumably sets up the OpenCL platform the tests run on.
    """
    plat.init_cl(1)
def main(): pl.init_cl(1) q = pl.qs[0] argparser = argparse.ArgumentParser() argparser.add_argument('--params', help='saved network parameters') argparser.add_argument('-v', '--validate', action='store_true', help='use last batch as validation set') args = argparser.parse_args() net_params = None if args.params: net_params = load_net(q, args.params) rng = np.random.RandomState(1234) anet = Alexnet(rng, 10, params=net_params) with open('/home/petar/datasets/cifar-10.pkl', 'rb') as f: [(X, Y), (testX, testY)] = cPickle.load(f) n_epochs = 10 batch_size = 128 n_train_batches = Y.shape[0] / batch_size n_valid_batches = 0 if args.validate: n_train_batches = int(0.8 * n_train_batches) n_valid_batches = int(0.2 * Y.shape[0] / batch_size) n_test_batches = testY.shape[0] / batch_size print 'starting training...' start_time = timeit.default_timer() # preload batches into GPU memory X_batches = [] Y_batches = [] X_validbs = [] Y_validbs = [] X_test = [] Y_test = [] for minibatch_index in xrange(n_train_batches): X_batches.append( clarray.to_device( pl.qs[0], X[minibatch_index * batch_size:(minibatch_index + 1) * batch_size])) Y_batches.append( clarray.to_device( pl.qs[0], pd.get_dummies( Y[minibatch_index * batch_size:(minibatch_index + 1) * batch_size]).values.astype(np.float32))) for minibatch_index in xrange(n_train_batches, n_train_batches + n_valid_batches): X_validbs.append( clarray.to_device( pl.qs[0], X[minibatch_index * batch_size:(minibatch_index + 1) * batch_size])) Y_validbs.append( clarray.to_device( pl.qs[0], Y[minibatch_index * batch_size:(minibatch_index + 1) * batch_size])) for minibatch_index in xrange(n_test_batches): X_test.append( clarray.to_device( pl.qs[0], testX[minibatch_index * batch_size:(minibatch_index + 1) * batch_size])) Y_test.append( clarray.to_device( pl.qs[0], testY[minibatch_index * batch_size:(minibatch_index + 1) * batch_size])) epoch = 0 lrn_rate = 0.01 momentum = 0.0 while epoch < n_epochs: epoch += 1 print 'epoch:', epoch, 'of', 
n_epochs, 'lr:', lrn_rate, 'm:', momentum for mbi in xrange(n_train_batches): print '\r>training batch', mbi, 'of', n_train_batches, sys.stdout.flush() anet.train(X_batches[mbi], Y_batches[mbi], lrn_rate, momentum) if args.validate and mbi % 65 == 0: verr = np.mean([ float(anet.test(X_validbs[vbi], Y_validbs[vbi])) for vbi in range(n_valid_batches) ]) print '\rvalidation error:', verr if epoch % 8 == 0 and epoch > 0: lrn_rate /= 10.0 # momentum += 0.1 print print '=' * 70 # print '>>final wc:\n', anet.conv1.params[0][1] print 'test error:', es = [] for mbi in xrange(n_test_batches): er = anet.test(X_test[mbi], Y_test[mbi]) # print 'test batch', mbi, 'error:', er es.append(er) print np.mean([float(e) for e in es]) end_time = timeit.default_timer() print 'ran for %.2fm' % ((end_time - start_time) / 60.) save_net(anet, 'alexnet_cifar.pkl')
def setUpClass(cls):
    """Class-level fixture hook: call ``clplatf.init_cl(1)`` once for the class.

    NOTE(review): presumably sets up the OpenCL platform the tests run on.
    """
    clplatf.init_cl(1)