def train(snapshotroot, device, forestType, numTrees, depth):
    """Train a ``Net`` on USPS, keep the best-validation snapshot, and test it.

    Training runs for a fixed 200 epochs with Adam (lr=0.001), cross-entropy
    loss and mini-batches of 23 samples. After every epoch the model's
    ``state_dict`` is saved to ``<snapshotroot>/epoch_<epoch>`` and the
    validation loss is computed. Finally the snapshot with the lowest
    validation loss is reloaded and its accuracy on the test split is printed
    and returned.

    Args:
        snapshotroot: Directory that receives the per-epoch snapshot files.
        device: Torch device the data, model and loss are moved to.
        forestType: Forwarded unchanged to ``Net``.
        numTrees: Forwarded unchanged to ``Net``.
        depth: Forwarded unchanged to ``Net``.

    Returns:
        ``(accuracy, valLosses)``: the test accuracy (a float in [0, 1]) of the
        best snapshot, and a numpy array with the validation loss of each epoch.
    """
    xtrain, ytrain, xtest, ytest = datasets.load_usps()

    # Flatten every sample to a 256-element feature vector.
    xtrain = np.reshape(xtrain, [-1, 256])
    xtest = np.reshape(xtest, [-1, 256])

    # XXX: Other papers use val = test for this data set, so model selection
    # below is effectively done on the test split.
    xval = xtest
    yval = ytest

    # Transfer this data to the device
    xtrain = torch.from_numpy(xtrain).type(torch.float32).to(device)
    ytrain = torch.from_numpy(ytrain).type(torch.long).to(device)

    xval = torch.from_numpy(xval).type(torch.float32).to(device)
    yval = torch.from_numpy(yval).type(torch.long).to(device)

    xtest = torch.from_numpy(xtest).type(torch.float32).to(device)
    ytest = torch.from_numpy(ytest).type(torch.long).to(device)

    net = Net(forestType, numTrees, depth).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # Count parameters
    numParams = sum(params.numel() for params in net.parameters())
    numTrainable = sum(params.numel() for params in net.parameters() if params.requires_grad)
    print(
        f"There are {numParams} parameters total in this model ({numTrainable} are trainable)"
    )

    numEpochs = 200
    batchSize = 23

    indices = list(range(xtrain.shape[0]))

    bestEpoch = numEpochs - 1
    # Any finite validation loss beats the initial value.
    bestLoss = float("inf")

    valLosses = np.zeros([numEpochs])

    for epoch in range(numEpochs):
        # Reorder the training data each epoch; inputs and labels are indexed
        # with the same permutation so sample/label pairs stay aligned.
        random.shuffle(indices)
        xtrain = xtrain[indices, :]
        ytrain = ytrain[indices]

        runningLoss = 0.0
        count = 0

        for xbatch, ybatch in batches(xtrain, ytrain, batchSize):
            optimizer.zero_grad()

            outputs = net(xbatch)
            loss = criterion(outputs, ybatch)

            loss.backward()
            optimizer.step()

            # Accumulate a plain Python float: summing the loss tensor itself
            # would build up autograd history across the whole epoch.
            runningLoss += loss.item()
            count += 1

        meanLoss = runningLoss / count

        snapshotFile = os.path.join(snapshotroot, f"epoch_{epoch}")
        torch.save(net.state_dict(), snapshotFile)

        # The whole validation split is evaluated as one batch.
        with torch.no_grad():
            net.train(False)
            valLoss = criterion(net(xval), yval).item()
            net.train(True)

        if valLoss < bestLoss:
            bestLoss = valLoss
            bestEpoch = epoch

        print(
            f"Info: Epoch = {epoch}, loss = {meanLoss}, validation loss = {valLoss}",
            flush=True)

        valLosses[epoch] = valLoss

    # Rebuild the model from the snapshot with the lowest validation loss.
    snapshotFile = os.path.join(snapshotroot, f"epoch_{bestEpoch}")
    net = Net(forestType, numTrees, depth)
    net.load_state_dict(torch.load(snapshotFile, map_location="cpu"))
    net = net.to(device)

    # The whole test split is evaluated as one batch.
    with torch.no_grad():
        net.train(False)
        predictions = torch.argmax(net(xtest), dim=1)
        totalCorrect = int(torch.sum(predictions == ytest))

    count = xtest.shape[0]
    accuracy = float(totalCorrect) / float(count)

    print(
        f"Info: Best epoch = {bestEpoch}, test accuracy = {accuracy}, misclassification rate = {1.0 - accuracy}",
        flush=True)

    return accuracy, valLosses
# Remaining command-line options for the pretraining run.
# NOTE(review): `parser` is created, and `--dataset` is declared, outside this
# chunk. Confirm that `--dataset` restricts values to 'mnist'/'usps';
# otherwise `x`/`y` below are never bound.
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--epochs', default=200, type=int)
parser.add_argument('--save_dir', default='results/temp', type=str)
args = parser.parse_args()
print(args)

# Create the output directory that receives the model plot and the CSV log.
import os
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)

# load dataset
from datasets import load_mnist, load_usps
if args.dataset == 'mnist':
    x, y = load_mnist()
elif args.dataset == 'usps':
    x, y = load_usps('data/usps')

# define the model
# The model's input shape is the data shape without its first (sample) axis.
model = CAE(input_shape=x.shape[1:], filters=[32, 64, 128, 10])
plot_model(model, to_file=args.save_dir + '/%s-pretrain-model.png' % args.dataset, show_shapes=True)
model.summary()

# compile the model and callbacks
# MSE loss together with fit(x, x, ...) below trains the model to reconstruct
# its own input.
optimizer = 'adam'
model.compile(optimizer=optimizer, loss='mse')
from keras.callbacks import CSVLogger
# Epoch results are written to <save_dir>/<dataset>-pretrain-log.csv.
csv_logger = CSVLogger(args.save_dir + '/%s-pretrain-log.csv' % args.dataset)

# begin training
# t0 records the start time; presumably used after this chunk to report
# training duration.
t0 = time()
model.fit(x, x, batch_size=args.batch_size, epochs=args.epochs, callbacks=[csv_logger])
def sdec(dataset="mnist", gamma=0.1, beta=1, maxiter=2e4, update_interval=20,
         tol=0.00001, batch_size=256):
    """Run SDEC clustering on one of the supported datasets and report accuracy.

    Arguments:
        dataset: the dataset to run on; one of 'mnist', 'usps', 'stl',
            'cifar_10'.
        gamma: the Lambda coefficient as named in the lecture.
        beta: the proportion of information known in advance about the samples.
        maxiter: forwarded as ``maxiter`` to ``dec.clustering``.
        update_interval: forwarded to ``dec.clustering``. NOTE: every supported
            dataset overrides it with its own recommended value (mnist 140,
            usps 30, stl 20, cifar_10 40), so the passed value is not used.
        tol: forwarded as ``tol`` to ``dec.clustering``.
        batch_size: mini-batch size handed to the SDEC model.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Weights file handed to initialize_model (presumably the pretrained
    # autoencoder for this dataset).
    ae_weights = ("ae_weights/" + dataset + "_ae_weights/" + dataset + "_ae_weights.h5")

    # load dataset
    from datasets import load_mnist, load_usps, load_stl, load_cifar
    if dataset == 'mnist':  # recommends: n_clusters=10, update_interval=140
        x, y = load_mnist('./data/mnist/mnist.npz')
        update_interval = 140
    elif dataset == 'usps':  # recommends: n_clusters=10, update_interval=30
        x, y = load_usps('data/usps')
        update_interval = 30
    elif dataset == "stl":
        x, y = load_stl()
        update_interval = 20
    elif dataset == "cifar_10":
        x, y = load_cifar()
        update_interval = 40
    else:
        # Without this, x/y would be unbound and fail later with a NameError.
        raise ValueError("unknown dataset: %r" % (dataset,))

    print(gamma, dataset, beta)

    # prepare the SDEC model
    # The number of clusters is the number of distinct labels. If y is a 2-D
    # numpy array its rows are unhashable, so count its first column instead.
    try:
        count = Counter(y)
    except TypeError:
        count = Counter(y[:, 0])
    n_clusters = len(count)

    # NOTE(review): ':' in this path is not valid on Windows file systems.
    save_dir = 'results/sdec_dataset:' + dataset + " gamma:" + str(gamma)
    # Size of the final, partial mini-batch (0 when the sample count divides evenly).
    laster_batch_size = x.shape[0] % batch_size
    dec = SDEC(dims=[x.shape[-1], 500, 500, 2000, 10], n_clusters=n_clusters,
               N=x.shape[0], x=x, batch_size=batch_size,
               laster_batch_size=laster_batch_size, gamma=gamma, beta=beta)
    dec.initialize_model(optimizer=SGD(lr=0.01, momentum=0.9), ae_weights=ae_weights)
    dec.model.summary()

    t0 = time()
    y_pred = dec.clustering(x, y=y, tol=tol, maxiter=maxiter,
                            update_interval=update_interval, save_dir=save_dir)
    plot_model(dec.model, to_file='sdecmodel.png', show_shapes=True)
    print('acc:', cluster_acc(y, y_pred))
    print('clustering time: ', (time() - t0))