def load_model_from_name(x):
    opt = cifar10Manager.load_metadata(x, 0)[0]
    # small old bug in the saving of metadata, this is a cheap trick to remedy it
    for key, val in opt.items():
        if isinstance(val, str):
            opt[key] = eval(val)
    model = convForwModel.ConvolForwardNet(**opt)
    if USE_CUDA:
        model = model.cuda()
    model.load_state_dict(cifar10Manager.load_model_state_dict(x))
    return model
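# Usage sketch (an illustration, assuming the model was previously registered and
# trained through cifar10Manager, as is done for 'cifar10_teacher' below):
#   teacherModel = load_model_from_name('cifar10_teacher')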
# Flags controlling which parts of the experiment are run
TRAIN_TEACHER_MODEL = True
TRAIN_SMALLER_MODEL = False
TRAIN_SMALLER_QUANTIZED_MODEL = False
TRAIN_DISTILLED_MODEL = False
TRAIN_DIFFERENTIABLE_QUANTIZATION = False
CHECK_PM_QUANTIZATION = True

batch_size = 25
cifar10 = datasets.CIFAR10()
train_loader, test_loader = cifar10.getTrainLoader(batch_size), cifar10.getTestLoader(batch_size)

# Teacher model
model_name = 'cifar10_teacher'
teacherModelPath = os.path.join(cifar10modelsFolder, model_name)
teacherModel = convForwModel.ConvolForwardNet(**convForwModel.teacherModelSpec,
                                              useBatchNorm=USE_BATCH_NORM,
                                              useAffineTransformInBatchNorm=AFFINE_BATCH_NORM)
if USE_CUDA:
    teacherModel = teacherModel.cuda()
if model_name not in cifar10Manager.saved_models:
    cifar10Manager.add_new_model(model_name, teacherModelPath,
                                 arguments_creator_function={**convForwModel.teacherModelSpec,
                                                             'useBatchNorm': USE_BATCH_NORM,
                                                             'useAffineTransformInBatchNorm': AFFINE_BATCH_NORM})
if TRAIN_TEACHER_MODEL:
    cifar10Manager.train_model(teacherModel, model_name=model_name,
                               train_function=convForwModel.train_model,
                               arguments_train_function={'epochs_to_train': epochsToTrainCIFAR},
                               train_loader=train_loader,
                               test_loader=test_loader)
teacherModel.load_state_dict(cifar10Manager.load_model_state_dict(model_name))
cnn_hf.evaluateModel(teacherModel, test_loader, k=5)
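# Sketch (an assumption, not the repo's own CHECK_PM_QUANTIZATION routine): a plain
# "post-mortem" uniform quantization check written with ordinary torch ops, rounding
# every weight tensor of the trained teacher to 2**numBit evenly spaced levels and
# re-evaluating the result; numBit = 4 is an arbitrary choice for illustration.
if CHECK_PM_QUANTIZATION:
    import copy
    numBit = 4  # assumed bit width for this illustration
    pm_model = copy.deepcopy(teacherModel)
    with torch.no_grad():
        for p in pm_model.parameters():
            min_w, max_w = p.min(), p.max()
            if max_w.item() == min_w.item():
                continue  # constant tensor, nothing to quantize
            scale = (max_w - min_w) / (2 ** numBit - 1)
            p.copy_(torch.round((p - min_w) / scale) * scale + min_w)
    cnn_hf.evaluateModel(pm_model, test_loader, k=5)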
epochsToTrainCIFAR = args.epochs
USE_BATCH_NORM = True
AFFINE_BATCH_NORM = True

if args.data == 'cifar10':
    data = datasets.CIFAR10()
elif args.data == 'cifar100':
    data = datasets.CIFAR100()

if args.test_memory:
    # Build either a student or the teacher architecture and measure its
    # inference-time memory footprint on the test set
    if not args.train_teacher:
        model = convForwModel.ConvolForwardNet(**smallerModelSpecs[args.stModel],
                                               activation=args.stud_act,
                                               numBins=args.num_bins,
                                               useBatchNorm=USE_BATCH_NORM,
                                               useAffineTransformInBatchNorm=AFFINE_BATCH_NORM)
    else:
        model = convForwModel.ConvolForwardNet(**convForwModel.teacherModelSpec,
                                               useBatchNorm=USE_BATCH_NORM,
                                               useAffineTransformInBatchNorm=AFFINE_BATCH_NORM)
    if USE_CUDA:
        model = model.cuda()
    test_loader = data.getTestLoader(1)

    import time
    start = time.time()
    cnn_hf.evaluateModel(model, test_loader)
    mem = torch.cuda.max_memory_allocated()
    end = time.time()
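    # Minimal reporting sketch (not in the original excerpt): print the elapsed
    # wall-clock time and the peak CUDA memory recorded above.
    print('Inference on the test set took {:.2f} seconds'.format(end - start))
    print('Peak CUDA memory allocated: {:.2f} MB'.format(mem / (1024 ** 2)))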
import os
import datasets
import model_manager
import cnn_models.conv_forward_model as convForwModel
import cnn_models.help_fun as cnn_hf

datasets.BASE_DATA_FOLDER = 'data'
batch_size = 50
epochsToTrainCIFAR = 100

TRAIN_TEACHER_MODEL = False
TRAIN_DISTILLED_MODEL = True
TRAIN_SMALLER_MODEL = True
TRAIN_DISTILLED_QUANTIZED_MODEL = True

cifar10 = datasets.CIFAR10()  # -> will be saved in BASE_DATA_FOLDER/cifar10, i.e. data/cifar10
train_loader, test_loader = cifar10.getTrainLoader(batch_size), cifar10.getTestLoader(batch_size)

# Teacher model (full-precision reference network)
teacherModel = convForwModel.ConvolForwardNet(**convForwModel.teacherModelSpec,
                                              useBatchNorm=True,
                                              useAffineTransformInBatchNorm=True)
#convForwModel.train_model(teacherModel, train_loader, test_loader, epochs_to_train=20)

# Model manager: keeps track of saved models, their metadata and training runs
model_manager_path = 'model_manager_cifar10.tst'
model_save_path = 'models'
os.makedirs(model_save_path, exist_ok=True)
cifar10Manager = model_manager.ModelManager(model_manager_path, 'model_manager',
                                            create_new_model_manager=not os.path.exists(model_manager_path))
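# Sketch (assumptions flagged below): how a distilled student for each entry of
# smallerModelSpecs could be registered and trained through the same manager API
# used for the teacher above. smallerModelSpecs is assumed to be the list of student
# specs referenced elsewhere in this file, and the keyword names 'teacher_model' and
# 'use_distillation_loss' passed to convForwModel.train_model are assumptions for
# illustration, not confirmed parts of that function's signature.
if TRAIN_DISTILLED_MODEL:
    for idx_spec, spec in enumerate(smallerModelSpecs):
        distilled_model_name = 'cifar10_distilled_spec{}'.format(idx_spec)
        distilledModelPath = os.path.join(model_save_path, distilled_model_name)
        distilledModel = convForwModel.ConvolForwardNet(**spec,
                                                        useBatchNorm=True,
                                                        useAffineTransformInBatchNorm=True)
        if USE_CUDA:
            distilledModel = distilledModel.cuda()
        if distilled_model_name not in cifar10Manager.saved_models:
            cifar10Manager.add_new_model(distilled_model_name, distilledModelPath,
                                         arguments_creator_function={**spec,
                                                                     'useBatchNorm': True,
                                                                     'useAffineTransformInBatchNorm': True})
        cifar10Manager.train_model(distilledModel, model_name=distilled_model_name,
                                   train_function=convForwModel.train_model,
                                   arguments_train_function={'epochs_to_train': epochsToTrainCIFAR,
                                                             'teacher_model': teacherModel,        # assumed kwarg
                                                             'use_distillation_loss': True},       # assumed kwarg
                                   train_loader=train_loader, test_loader=test_loader)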
distilled_model_names = ['cifar10_distilled_spec{}'.format(idx_spec)
                         for idx_spec in range(len(smallerModelSpecs))]
for distilled_model_name in distilled_model_names:
    modelOptions = cifar10Manager.load_metadata(distilled_model_name, 0)[0]
    # small old bug in the saving of metadata, this is a cheap trick to remedy it
    for key, val in modelOptions.items():
        if isinstance(val, str):
            modelOptions[key] = eval(val)
    for numBit in numBits:
        if numBit == 8:
            continue
        distilled_quantized_model_name = distilled_model_name + '_quant_points_{}bits'.format(numBit)
        distilled_quantized_model = convForwModel.ConvolForwardNet(**modelOptions)
        if USE_CUDA:
            distilled_quantized_model = distilled_quantized_model.cuda()
        save_path = cifar10Manager.get_model_base_path(distilled_model_name) + \
                    'quant_points_{}bits'.format(numBit)
        with open(save_path, 'rb') as p:
            quantization_points, infoDict = pickle.load(p)
        distilled_quantized_model.load_state_dict(torch.load(save_path + '_model_state_dict'))
        quantization_functions = [functools.partial(quantization.nonUniformQuantization,
                                                    listQuantizationPoints=qp,
                                                    bucket_size=256)
                                  for qp in quantization_points]
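        # Sketch (not from the original excerpt): with the state dict and the learned
        # quantization points loaded, the reconstructed model can be checked on the
        # test set with the same helper used for the teacher above.
        cnn_hf.evaluateModel(distilled_quantized_model, test_loader, k=5)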