def train_model(model_name, data_size, training_epochs, dset):
    """Train a model from scratch and log held-out test accuracy.

    Args:
        model_name: "SWSModel" or "LeNet5"; any other value falls back to
            LeNet_300_100 (original behavior preserved).
        data_size: 'search' (reduced search split) or 'full' (whole train set).
        training_epochs: number of passes over the training data.
        dset: dataset identifier forwarded to the data loaders (e.g. "mnist").

    Raises:
        ValueError: if data_size is not 'search' or 'full'. (Previously an
            unrecognized value left `train_dataset` undefined and crashed
            later with a NameError.)
    """
    if data_size == 'search':
        train_dataset = search_train_data(dset=dset)
    elif data_size == 'full':
        train_dataset = train_data(dset=dset)
    else:
        raise ValueError(
            "data_size must be 'search' or 'full', got {!r}".format(data_size))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True)

    # Full test set, cached on the GPU for the periodic accuracy checks.
    test_data_full = Variable(test_data(fetch='data', dset=dset)).cuda()
    test_labels_full = Variable(test_data(fetch='labels', dset=dset)).cuda()

    if model_name == "SWSModel":
        model = model_archs.SWSModel().cuda()
    elif model_name == "LeNet5":
        model = model_archs.LeNet5().cuda()
    else:
        # Fallback (original behavior): any other name trains LeNet_300_100.
        model = model_archs.LeNet_300_100().cuda()
    print("Model Name: {} Epochs: {} Data: {}".format(
        model.name, training_epochs, data_size))
    print_dims(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)

    for epoch in range(training_epochs):
        model, loss = train_epoch(model, optimizer, criterion, train_loader)
        # Log test accuracy only every 10th epoch to keep output readable.
        if trueAfterN(epoch, 10):
            test_acc = test_accuracy(test_data_full, test_labels_full, model)
            print('Epoch: {}. Test Accuracy: {:.2f}'.format(epoch + 1, test_acc[0]))
# Standard library
import argparse  # was missing: argparse is used below but never imported
import copy
import os        # was missing: os.getcwd() is used below but never imported
import pickle

# Third-party
import torch
from torch.autograd import Variable

# Project
from mnist_loader import (search_train_data, search_retrain_data,
                          search_validation_data, train_data, test_data,
                          batch_size)
from retrain_model import retrain_model
from utils_misc import model_load_dir
from utils_model import test_accuracy, layer_accuracy, sws_replace
from utils_sws import sws_prune, compressed_model

# Directory where trained models are written.
savedir = os.getcwd() + "/models/"

if __name__ == "__main__":
    # Full test set plus the validation split used during hyperparameter
    # search, all cached on the GPU up front.
    test_data_full = Variable(test_data(fetch="data")).cuda()
    test_labels_full = Variable(test_data(fetch="labels")).cuda()
    val_data_full = Variable(search_validation_data(fetch="data")).cuda()
    val_labels_full = Variable(search_validation_data(fetch="labels")).cuda()

    parser = argparse.ArgumentParser()
    parser.add_argument('--start', dest="start", help="Start Search",
                        required=True, type=int)
    parser.add_argument('--end', dest="end", help="End Search",
                        required=True, type=int)
    args = parser.parse_args()
def __init__(self, init_model, gmp="", mode="retrain", full_model="", data_size='search', loss_type='CE', mv=(0, 0), zmv=(0, 0), tau=1, temp=0, mixtures=1, dset="mnist"):
    """Set up per-epoch statistics tracking for a (re)training run.

    Snapshots the initial flattened weights of every layer, initialises the
    accuracy/loss history lists and pruning bookkeeping, optionally records
    the Gaussian mixture prior's parameters, and caches evaluation tensors
    on the GPU.
    """
    # Layer names come from the ".weight" entries of the state dict.
    state = init_model.state_dict()
    self.layers = [key.replace(".weight", "") for key in state if "weight" in key]
    # Initial weights and biases of each layer, flattened into one 1-D array.
    self.layer_init_weights = {
        name: np.concatenate([
            state[name + ".weight"].clone().view(-1).cpu().numpy(),
            state[name + ".bias"].clone().view(-1).cpu().numpy(),
        ])
        for name in self.layers
    }
    self.layer_weights = {}
    self.mode = mode
    self.loss_type = loss_type
    self.data_size = data_size

    # Accuracy and loss history, appended to once per epoch.
    self.epochs = []
    self.train_accuracy, self.test_accuracy, self.val_accuracy = [], [], []
    self.train_loss, self.test_loss, self.val_loss = [], [], []
    self.complexity_loss = []

    # Only layer-wise retraining needs a reference to the full model.
    if mode == 'layer_retrain':
        self.full_model = full_model

    # Pruning bookkeeping (populated later by the prune-tracking methods).
    self.prune_layer_weight = {}
    self.prune_acc = {}
    self.sparsity = 0

    self.mean, self.var = mv
    self.zmean, self.zvar = zmv
    self.tau = tau
    self.temp = temp
    self.mixtures = mixtures
    self.use_prune = False

    # Track the Gaussian mixture prior only when one was supplied for a
    # (layer_)retrain run.
    self.use_gmp = mode in ('retrain', 'layer_retrain') and gmp != ""
    if self.use_gmp:
        # gammas/rhos are stored in log space, hence the exp() calls.
        self.gmp_stddev = np.sqrt(1. / gmp.gammas.exp().data.clone().cpu().numpy())
        self.gmp_means = gmp.means.data.clone().cpu().numpy()
        self.gmp_mixprop = gmp.rhos.exp().data.clone().cpu().numpy()
        self.gmp_scale = gmp.scale.exp().data.clone().cpu().numpy()

    self.test_data_full = Variable(test_data(fetch='data', dset=dset)).cuda()
    self.test_labels_full = Variable(test_data(fetch='labels', dset=dset)).cuda()
    if data_size == 'search':
        # Search split: train on examples 40k-50k, validate on 50k-60k.
        self.val_data_full = Variable(
            train_data(fetch='data', dset=dset)[50000:60000]).cuda()
        self.val_labels_full = Variable(
            train_data(fetch='labels', dset=dset)[50000:60000]).cuda()
        self.train_data_full = Variable(
            train_data(fetch='data', dset=dset)[40000:50000]).cuda()
        self.train_labels_full = Variable(
            train_data(fetch='labels', dset=dset)[40000:50000]).cuda()
    else:
        self.train_data_full = Variable(train_data(fetch='data', dset=dset)).cuda()
        self.train_labels_full = Variable(
            train_data(fetch='labels', dset=dset)).cuda()
def retrain_model(mean, var, zmean, zvar, tau, temp, mixtures, model_name, data_size, loss_type='MSESNT', scaling=False, model_save_dir="", fn="", dset="mnist"):
    """Retrain a saved model under a Gaussian mixture prior, then prune it.

    Args:
        mean, var: prior mean/variance for the non-zero mixture components
            (converted to (a, b) hyperparameters via get_ab).
        zmean, zvar: same for the zero-pinned component.
        tau: complexity-loss trade-off forwarded to retrain_sws_epoch.
        temp: knowledge-distillation temperature; 0 trains on dataset labels,
            otherwise soft targets are loaded from saved teacher outputs.
        mixtures: number of mixture components in the prior.
        model_name: name used to locate the saved pre-trained model file.
        data_size: 'search' (10k-example search split) or 'full' (all 60k).
        loss_type: loss identifier forwarded to retrain_sws_epoch.
        scaling: whether the prior uses learned per-layer scales.
        model_save_dir: if non-empty, model/prior/results are pickled there.
        fn: optional suffix appended to the experiment name.
        dset: dataset identifier (e.g. "mnist").

    Returns:
        (model, gmp, res): retrained model, fitted mixture prior, and a dict
        of accuracy / sparsity / compression statistics.

    Raises:
        ValueError: if data_size is not 'search' or 'full' (previously this
            fell through and crashed later with a NameError).
    """
    ab = get_ab(mean, var)
    zab = get_ab(zmean, zvar)

    if data_size == 'search':
        train_dataset = search_retrain_data
        # Validation tensors for the search split (kept as in the original;
        # presumably consumed indirectly — TODO confirm they are needed).
        val_data_full = Variable(search_validation_data(fetch='data', dset=dset)).cuda()
        val_labels_full = Variable(search_validation_data(fetch='labels', dset=dset)).cuda()
        (x_start, x_end) = (40000, 50000)
    elif data_size == 'full':
        train_dataset = train_data
        (x_start, x_end) = (0, 60000)
    else:
        raise ValueError(
            "data_size must be 'search' or 'full', got {!r}".format(data_size))

    test_data_full = Variable(test_data(fetch='data', dset=dset)).cuda()
    test_labels_full = Variable(test_data(fetch='labels', dset=dset)).cuda()

    model_file = '{}_{}_{}_{}'.format(dset, model_name, 100, data_size)
    model = torch.load(model_load_dir + model_file + '.m').cuda()

    if temp == 0:
        # Train on hard labels (one-hot targets for the MSE-softmax loss).
        loader = torch.utils.data.DataLoader(
            dataset=train_dataset(onehot=(loss_type == 'MSESNT'), dset=dset),
            batch_size=batch_size, shuffle=True)
    else:
        # Knowledge distillation: soften saved teacher logits with `temp`.
        output = torch.load("{}{}_targets/{}.out.m".format(
            model_load_dir, model_file,
            "fc3" if "300_100" in model.name else "fc2"))[x_start:x_end]
        output = (nn.Softmax(dim=1)(output / temp)).data
        dataset = torch.utils.data.TensorDataset(
            train_dataset(fetch='data', dset=dset), output)
        loader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=True)
    # (Removed: an unused `criterion = nn.CrossEntropyLoss()` local — the
    # loss is selected inside retrain_sws_epoch via loss_type.)

    scale_tag = "s" if scaling else "f"
    exp_name = "{}_m{}_zm{}_r{}_t{}_m{}_kdT{}_{}_{}".format(
        model.name, mean, zmean, retraining_epochs, tau, int(mixtures),
        int(temp), scale_tag, data_size) + fn

    gmp = GaussianMixturePrior(mixtures, [x for x in model.parameters()], 0.99,
                               zero_ab=zab, ab=ab, scaling=scaling)
    gmp.print_batch = False

    # Per-group learning rates: network weights, mixture means, and the
    # log-precisions/mixing proportions each get their own group.
    mlr = 0.5e-4  # same for scaled and unscaled runs (was a no-op ternary)
    optimizable_params = [
        {'params': model.parameters(), 'lr': 2e-4},
        {'params': [gmp.means], 'lr': mlr},
        {'params': [gmp.gammas, gmp.rhos], 'lr': 3e-3}]
    if scaling:
        slr = 2e-5 if "SWS" in model_name else 1e-6
        optimizable_params = optimizable_params + [{'params': gmp.scale, 'lr': slr}]
    opt = torch.optim.Adam(optimizable_params)

    res_stats = plot_data(init_model=model, gmp=gmp, mode='retrain',
                          data_size=data_size, loss_type='CE', mv=(mean, var),
                          zmv=(zmean, zvar), tau=tau, temp=temp,
                          mixtures=mixtures, dset=dset)

    s_hist = []
    a_hist = []
    for epoch in range(retraining_epochs):
        # [ACT DISABLE LR] experiment toggle, kept for reference:
        # if scaling and epoch == 0:
        #     opt.param_groups[3]['lr'] = 0
        #     print("Scaling Disabled - Epoch {}".format(epoch))
        model, loss = retrain_sws_epoch(model, gmp, opt, loader, tau, temp, loss_type)
        res_stats.data_epoch(epoch + 1, model, gmp)

        # Prune a copy of the current model to track how pruned accuracy and
        # sparsity evolve. Previously this was computed twice on logging
        # epochs; compute once and reuse for both history and printing.
        nm = sws_prune_copy(model, gmp)
        s = get_sparsity(nm)
        a = test_accuracy(test_data_full, test_labels_full, nm)[0]
        s_hist.append(s)
        a_hist.append(a)
        if trueAfterN(epoch, 10):
            print('Epoch: {}. Test Accuracy: {:.2f}, Prune Accuracy: {:.2f}, Sparsity: {:.2f}'
                  .format(epoch + 1, res_stats.test_accuracy[-1], a, s))

        # Early termination during hyperparameter search once the tracked
        # accuracy has clearly collapsed.
        if data_size == 'search' and epoch > 12 and trueAfterN(epoch, 2):
            val_acc = res_stats.test_accuracy[-1]
            if val_acc < 50.0:
                print("Terminating Search - Epoch: {} - Val Acc: {:.2f}".format(epoch, val_acc))
                break

    # Final prune + result dict. (The original called gen_dict() once before
    # data_prune() and immediately overwrote the result — dead code removed.)
    model_prune = sws_prune_copy(model, gmp)
    res_stats.data_prune(model_prune)
    res = res_stats.gen_dict()
    res['test_prune_acc'] = a_hist
    res['test_prune_sp'] = s_hist
    cm = compressed_model(model_prune.state_dict(), [gmp])
    res['cm'] = cm.get_cr_list()

    if data_size == "search":
        print('Retrain Test: {:.2f}, Retrain Validation: {:.2f}, Prune Test: {:.2f}, Prune Validation: {:.2f}, Prune Sparsity: {:.2f}'
              .format(res['test_acc'][-1], res['val_acc'][-1],
                      res['prune_acc']['test'], res['prune_acc']['val'], res['sparsity']))
    else:
        print('Retrain Test: {:.2f}, Prune Test: {:.2f}, Prune Sparsity: {:.2f}'.format(
            res['test_acc'][-1], res['prune_acc']['test'], res['sparsity']))

    if model_save_dir != "":
        torch.save(model, model_save_dir + '/{}_retrain_model_{}.m'.format(dset, exp_name))
        with open(model_save_dir + '/{}_retrain_gmp_{}.p'.format(dset, exp_name), 'wb') as f:
            pickle.dump(gmp, f)
        with open(model_save_dir + '/{}_retrain_res_{}.p'.format(dset, exp_name), 'wb') as f:
            pickle.dump(res, f)
    return model, gmp, res