Example #1
def train_model(model_name, data_size, training_epochs, dset):
    if data_size == 'search':
        train_dataset = search_train_data(dset=dset)
    elif data_size == 'full':
        train_dataset = train_data(dset=dset)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_data_full = Variable(test_data(fetch='data', dset=dset)).cuda()
    test_labels_full = Variable(test_data(fetch='labels', dset=dset)).cuda()

    if model_name == "SWSModel":
        model = model_archs.SWSModel().cuda()
    elif model_name == "LeNet5":
        model = model_archs.LeNet5().cuda()
    else:
        model = model_archs.LeNet_300_100().cuda()

    print("Model Name: {} Epochs: {} Data: {}".format(model.name,
                                                      training_epochs,
                                                      data_size))
    print_dims(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 weight_decay=0.0001)

    for epoch in range(training_epochs):
        model, loss = train_epoch(model, optimizer, criterion, train_loader)

        if trueAfterN(epoch, 10):
            test_acc = test_accuracy(test_data_full, test_labels_full, model)
            print('Epoch: {}. Test Accuracy: {:.2f}'.format(
                epoch + 1, test_acc[0]))
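
A minimal call sketch (not in the original): assuming the imports and batch_size from Example #2 are in scope, training might be launched like this; the argument values are illustrative.

if __name__ == "__main__":
    # Hypothetical invocation; epoch count and dataset are assumptions.
    train_model(model_name="LeNet5", data_size="full",
                training_epochs=100, dset="mnist")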
Example #2
import argparse
import copy
import os
import pickle

import torch
from torch.autograd import Variable

from mnist_loader import (search_train_data, search_retrain_data,
                          search_validation_data, train_data, test_data,
                          batch_size)
from retrain_model import retrain_model
from utils_misc import model_load_dir
from utils_model import test_accuracy, layer_accuracy, sws_replace
from utils_sws import sws_prune, compressed_model

savedir = os.getcwd() + "/models/"

if __name__ == "__main__":
    test_data_full = Variable(test_data(fetch="data")).cuda()
    test_labels_full = Variable(test_data(fetch="labels")).cuda()
    val_data_full = Variable(search_validation_data(fetch="data")).cuda()
    val_labels_full = Variable(search_validation_data(fetch="labels")).cuda()
    parser = argparse.ArgumentParser()
    parser.add_argument('--start',
                        dest="start",
                        help="Start Search",
                        required=True,
                        type=int)
    parser.add_argument('--end',
                        dest="end",
                        help="End Search",
                        required=True,
                        type=int)
    args = parser.parse_args()
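
    # The listing ends after argument parsing; a plausible continuation
    # (assumed, not from the original) would sweep retrain_model over the
    # requested index range. The grid values below are illustrative only.
    grid = [(mean, temp) for mean in (0.0, 0.1) for temp in (0, 4, 8)]
    for i in range(args.start, args.end):
        mean, temp = grid[i % len(grid)]
        retrain_model(mean, 1e-3, 0.0, 1e-3, 1e-5, temp, 16,
                      "LeNet_300_100", "search",
                      model_save_dir=savedir, dset="mnist")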
Example #3
    def __init__(self,
                 init_model,
                 gmp="",
                 mode="retrain",
                 full_model="",
                 data_size='search',
                 loss_type='CE',
                 mv=(0, 0),
                 zmv=(0, 0),
                 tau=1,
                 temp=0,
                 mixtures=1,
                 dset="mnist"):
        self.layers = [
            x.replace(".weight", "") for x in init_model.state_dict().keys()
            if "weight" in x
        ]
        self.layer_init_weights = {}
        for l in self.layers:
            self.layer_init_weights[l] = np.concatenate([
                init_model.state_dict()[l + ".weight"].clone().view(
                    -1).cpu().numpy(),
                init_model.state_dict()[l + ".bias"].clone().view(
                    -1).cpu().numpy()
            ])
        self.layer_weights = {}

        self.mode = mode
        self.loss_type = loss_type

        # accuracy and loss tracking flags
        self.data_size = data_size

        #accuracy and loss history
        self.epochs = []
        self.train_accuracy = []
        self.test_accuracy = []
        self.val_accuracy = []
        self.train_loss = []
        self.test_loss = []
        self.val_loss = []
        self.complexity_loss = []

        if mode == 'layer_retrain':
            self.full_model = full_model

        self.prune_layer_weight = {}
        self.prune_acc = {}
        self.sparsity = 0

        self.mean = mv[0]
        self.var = mv[1]
        self.zmean = zmv[0]
        self.zvar = zmv[1]
        self.tau = tau
        self.temp = temp
        self.mixtures = mixtures
        self.use_prune = False

        #gmp tracking
        self.use_gmp = (mode == 'retrain' or mode == 'layer_retrain') and gmp != ""
        if self.use_gmp:
            self.gmp_stddev = np.sqrt(
                1. / gmp.gammas.exp().data.clone().cpu().numpy())
            self.gmp_means = gmp.means.data.clone().cpu().numpy()
            self.gmp_mixprop = gmp.rhos.exp().data.clone().cpu().numpy()
            self.gmp_scale = gmp.scale.exp().data.clone().cpu().numpy()

        self.test_data_full = Variable(test_data(fetch='data',
                                                 dset=dset)).cuda()
        self.test_labels_full = Variable(test_data(fetch='labels',
                                                   dset=dset)).cuda()

        if data_size == 'search':
            self.val_data_full = Variable(
                train_data(fetch='data', dset=dset)[50000:60000]).cuda()
            self.val_labels_full = Variable(
                train_data(fetch='labels', dset=dset)[50000:60000]).cuda()
            self.train_data_full = Variable(
                train_data(fetch='data', dset=dset)[40000:50000]).cuda()
            self.train_labels_full = Variable(
                train_data(fetch='labels', dset=dset)[40000:50000]).cuda()

        else:
            self.train_data_full = Variable(train_data(fetch='data',
                                                       dset=dset)).cuda()
            self.train_labels_full = Variable(
                train_data(fetch='labels', dset=dset)).cuda()
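
Example #4 below constructs this tracker as plot_data(...); a minimal standalone sketch (the model choice and mixture parameters are illustrative placeholders, not from the original):

# Track accuracy/loss statistics for a freshly built model;
# leaving gmp="" disables the GMP tracking branch above.
model = model_archs.LeNet_300_100().cuda()
stats = plot_data(init_model=model, mode='retrain', data_size='search',
                  mv=(0.0, 1e-3), zmv=(0.0, 1e-3), dset='mnist')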
Example #4
def retrain_model(mean, var, zmean, zvar, tau, temp, mixtures, model_name,
		data_size, loss_type='MSESNT', scaling=False,
		model_save_dir="", fn="", dset="mnist"):
	ab = get_ab(mean, var)
	zab = get_ab(zmean, zvar)

	if data_size == 'search':
		train_dataset = search_retrain_data
		val_data_full = Variable(search_validation_data(fetch='data', dset=dset)).cuda()
		val_labels_full = Variable(search_validation_data(fetch='labels', dset=dset)).cuda()
		(x_start, x_end) = (40000, 50000)
	elif data_size == 'full':
		train_dataset = train_data
		(x_start, x_end) = (0, 60000)
	test_data_full = Variable(test_data(fetch='data', dset=dset)).cuda()
	test_labels_full = Variable(test_data(fetch='labels', dset=dset)).cuda()
		
	model_file = '{}_{}_{}_{}'.format(dset, model_name, 100, data_size)
	model = torch.load(model_load_dir + model_file + '.m').cuda()
		
	if temp == 0:
		loader = torch.utils.data.DataLoader(
			dataset=train_dataset(onehot=(loss_type == 'MSESNT'), dset=dset),
			batch_size=batch_size, shuffle=True)
	else:
		# distillation targets: soften the saved teacher logits with temperature
		output = torch.load("{}{}_targets/{}.out.m".format(
			model_load_dir, model_file,
			"fc3" if "300_100" in model.name else "fc2"))[x_start:x_end]
		output = (nn.Softmax(dim=1)(output / temp)).data
		dataset = torch.utils.data.TensorDataset(train_dataset(fetch='data', dset=dset), output)
		loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
	criterion = nn.CrossEntropyLoss()
	s = "s" if scaling else "f" 
	exp_name = "{}_m{}_zm{}_r{}_t{}_m{}_kdT{}_{}_{}".format(model.name, mean, zmean, retraining_epochs, tau, int(mixtures), int(temp), s, data_size) + fn
	gmp = GaussianMixturePrior(mixtures, [x for x in model.parameters()], 0.99,
		zero_ab=zab, ab=ab, scaling=scaling)
	gmp.print_batch = False

	mlr = 0.5e-4  # mixture-mean learning rate; the same with or without scaling

	optimizable_params = [
		{'params': model.parameters(), 'lr': 2e-4},
		{'params': [gmp.means], 'lr': mlr},
		# gammas/rhos hold the log precisions and log mixing proportions
		{'params': [gmp.gammas, gmp.rhos], 'lr': 3e-3}]
	if scaling:
		slr = 2e-5 if "SWS" in model_name else 1e-6
		optimizable_params.append({'params': gmp.scale, 'lr': slr})

	opt = torch.optim.Adam(optimizable_params)

	res_stats = plot_data(init_model=model, gmp=gmp, mode='retrain',
		data_size=data_size, loss_type='CE', mv=(mean, var),
		zmv=(zmean, zvar), tau=tau, temp=temp, mixtures=mixtures, dset=dset)
	s_hist = []
	a_hist = []
	for epoch in range(retraining_epochs):
		# [ACT DISABLE LR] optionally freeze the scale learning rate on the first epoch:
		#if scaling and epoch == 0:
		#	opt.param_groups[3]['lr'] = 0
		#	print("Scaling Disabled - Epoch {}".format(epoch))
		model, loss = retrain_sws_epoch(model, gmp, opt, loader, tau, temp, loss_type)
		res_stats.data_epoch(epoch + 1, model, gmp)

		# prune a copy of the model each epoch and record its sparsity/accuracy
		nm = sws_prune_copy(model, gmp)
		s = get_sparsity(nm)
		a = test_accuracy(test_data_full, test_labels_full, nm)[0]
		s_hist.append(s)
		a_hist.append(a)

		if trueAfterN(epoch, 10):
			print('Epoch: {}. Test Accuracy: {:.2f}, Prune Accuracy: {:.2f}, Sparsity: {:.2f}'.format(
				epoch + 1, res_stats.test_accuracy[-1], a, s))

		# early termination for the search setting, using the latest test-accuracy reading
		if data_size == 'search' and epoch > 12 and trueAfterN(epoch, 2):
			val_acc = res_stats.test_accuracy[-1]
			if val_acc < 50.0:
				print("Terminating Search - Epoch: {} - Val Acc: {:.2f}".format(epoch, val_acc))
				break
	
	model_prune = sws_prune_copy(model, gmp)
	res_stats.data_prune(model_prune)
	res = res_stats.gen_dict()
	res['test_prune_acc'] = a_hist
	res['test_prune_sp'] = s_hist
	cm = compressed_model(model_prune.state_dict(), [gmp])
	res['cm'] = cm.get_cr_list()

	if (data_size == "search"):
		print('Retrain Test: {:.2f}, Retrain Validation: {:.2f}, Prune Test: {:.2f}, Prune Validation: {:.2f}, Prune Sparsity: {:.2f}'
		.format(res['test_acc'][-1], res['val_acc'][-1], res['prune_acc']['test'], res['prune_acc']['val'], res['sparsity']))
	else:
		print('Retrain Test: {:.2f}, Prune Test: {:.2f}, Prune Sparsity: {:.2f}'.format(res['test_acc'][-1], res['prune_acc']['test'],res['sparsity']))

	if(model_save_dir!=""):
		torch.save(model, model_save_dir + '/{}_retrain_model_{}.m'.format(dset, exp_name))
		with open(model_save_dir + '/{}_retrain_gmp_{}.p'.format(dset, exp_name),'wb') as f:
			pickle.dump(gmp, f)
		with open(model_save_dir + '/{}_retrain_res_{}.p'.format(dset, exp_name),'wb') as f:
			pickle.dump(res, f)

	return model, gmp, res
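
A minimal call sketch (not in the original); the hyper-parameter values are illustrative assumptions, and the module-level globals used above (retraining_epochs, batch_size, model_load_dir) are presumed to be in scope:

model, gmp, res = retrain_model(mean=0.0, var=1e-3, zmean=0.0, zvar=1e-3,
                                tau=1e-5, temp=0, mixtures=16,
                                model_name="LeNet_300_100", data_size="search",
                                dset="mnist")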