Example #1
def load_model_from_name(model_name):
    opt = cifar10Manager.load_metadata(model_name, 0)[0]
    # small old bug in the saving of metadata; this is a cheap trick to remedy it
    for key, val in opt.items():
        if isinstance(val, str):
            opt[key] = eval(val)
    model = convForwModel.ConvolForwardNet(**opt)
    if USE_CUDA: model = model.cuda()
    model.load_state_dict(cifar10Manager.load_model_state_dict(model_name))
    return model
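The eval call above executes whatever string is stored in the metadata, which is risky if the file is ever edited by hand. Assuming the saved values are plain Python literals (numbers, lists, tuples), ast.literal_eval from the standard library is a safer drop-in for this trick; a minimal sketch with a hypothetical helper, not part of the project:

import ast

def parse_metadata_value(val):
    # Assumes metadata values are Python literals such as '3', '[64, 128]' or '(2, 2)'.
    if isinstance(val, str):
        try:
            return ast.literal_eval(val)   # parses literals without executing code
        except (ValueError, SyntaxError):
            return val                     # keep genuine strings (e.g. names) as-is
    return val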
Example #2
TRAIN_TEACHER_MODEL = True
TRAIN_SMALLER_MODEL = False
TRAIN_SMALLER_QUANTIZED_MODEL = False
TRAIN_DISTILLED_MODEL = False
TRAIN_DIFFERENTIABLE_QUANTIZATION = False
CHECK_PM_QUANTIZATION = True

batch_size = 25
cifar10 = datasets.CIFAR10()
train_loader, test_loader = cifar10.getTrainLoader(batch_size), cifar10.getTestLoader(batch_size)

# Teacher model
model_name = 'cifar10_teacher'
teacherModelPath = os.path.join(cifar10modelsFolder, model_name)
teacherModel = convForwModel.ConvolForwardNet(**convForwModel.teacherModelSpec,
                                              useBatchNorm=USE_BATCH_NORM,
                                              useAffineTransformInBatchNorm=AFFINE_BATCH_NORM)
if USE_CUDA: teacherModel = teacherModel.cuda()
if model_name not in cifar10Manager.saved_models:
    cifar10Manager.add_new_model(model_name, teacherModelPath,
            arguments_creator_function={**convForwModel.teacherModelSpec,
                                        'useBatchNorm':USE_BATCH_NORM,
                                        'useAffineTransformInBatchNorm':AFFINE_BATCH_NORM})
if TRAIN_TEACHER_MODEL:
    cifar10Manager.train_model(teacherModel, model_name=model_name,
                               train_function=convForwModel.train_model,
                               arguments_train_function={'epochs_to_train': epochsToTrainCIFAR},
                               train_loader=train_loader, test_loader=test_loader)
teacherModel.load_state_dict(cifar10Manager.load_model_state_dict(model_name))
cnn_hf.evaluateModel(teacherModel, test_loader, k=5)
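cnn_hf.evaluateModel is project code; assuming its k argument requests top-k accuracy, the metric itself is a few lines of plain PyTorch. A sketch under that assumption, not the project's implementation:

import torch

def topk_accuracy(model, loader, k=5):
    # Fraction of samples whose true label appears among the k highest-scoring classes.
    device = next(model.parameters()).device
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            topk = model(inputs).topk(k, dim=1).indices            # shape (batch, k)
            correct += (topk == targets.unsqueeze(1)).any(dim=1).sum().item()
            total += targets.size(0)
    return correct / total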
Example #3
File: main.py  Project: Flamexmt/LMA
    except Exception:
        pass

    epochsToTrainCIFAR = args.epochs
    USE_BATCH_NORM = True
    AFFINE_BATCH_NORM = True

    if args.data == 'cifar10':
        data = datasets.CIFAR10()
    elif args.data == 'cifar100':
        data = datasets.CIFAR100()
    else:
        raise ValueError('unsupported dataset: {}'.format(args.data))
    if args.test_memory:
        if not args.train_teacher:
            model = convForwModel.ConvolForwardNet(
                **smallerModelSpecs[args.stModel],
                activation=args.stud_act,
                numBins=args.num_bins,
                useBatchNorm=USE_BATCH_NORM,
                useAffineTransformInBatchNorm=AFFINE_BATCH_NORM)
        else:
            model = convForwModel.ConvolForwardNet(
                **convForwModel.teacherModelSpec,
                useBatchNorm=USE_BATCH_NORM,
                useAffineTransformInBatchNorm=AFFINE_BATCH_NORM)
        if USE_CUDA: model = model.cuda()
        test_loader = data.getTestLoader(1)
        import time

        start = time.time()
        cnn_hf.evaluateModel(model, test_loader)
        mem = torch.cuda.max_memory_allocated()
        end = time.time()
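Two caveats apply to the measurement above: torch.cuda.max_memory_allocated() reports the peak since program start (or the last reset), and CUDA kernels run asynchronously, so wrapping the call in time.time() can under-count. A tighter sketch using the same public PyTorch APIs (cnn_hf.evaluateModel is the project's own evaluation helper):

import time
import torch

def measure_inference(model, loader):
    torch.cuda.reset_peak_memory_stats()  # count only this evaluation's allocations
    torch.cuda.synchronize()              # don't let pending kernels leak into the timing
    start = time.time()
    cnn_hf.evaluateModel(model, loader)
    torch.cuda.synchronize()              # wait for the last kernels before stopping the clock
    elapsed = time.time() - start
    return elapsed, torch.cuda.max_memory_allocated()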
Example #4

datasets.BASE_DATA_FOLDER = 'data'

batch_size = 50
epochsToTrainCIFAR = 100
TRAIN_TEACHER_MODEL = False
TRAIN_DISTILLED_MODEL = True
TRAIN_SMALLER_MODEL = True
TRAIN_DISTILLED_QUANTIZED_MODEL = True

cifar10 = datasets.CIFAR10() # -> saved under datasets.BASE_DATA_FOLDER, i.e. data/cifar10
train_loader, test_loader = cifar10.getTrainLoader(batch_size), cifar10.getTestLoader(batch_size)

import cnn_models.conv_forward_model as convForwModel
import cnn_models.help_fun as cnn_hf
teacherModel = convForwModel.ConvolForwardNet(**convForwModel.teacherModelSpec, useBatchNorm=True, useAffineTransformInBatchNorm=True)
#convForwModel.train_model(teacherModel, train_loader, test_loader, epochs_to_train=20)

import os
import model_manager
model_manager_path = 'model_manager_cifar10.tst'
model_save_path = 'models'
os.makedirs(model_save_path, exist_ok=True)

create_new = not os.path.exists(model_manager_path)
cifar10Manager = model_manager.ModelManager(model_manager_path, 'model_manager',
                                            create_new_model_manager=create_new)
Example #5
distilled_model_names = [
    'cifar10_distilled_spec{}'.format(idx_spec)
    for idx_spec in range(len(smallerModelSpecs))
]
for distilled_model_name in distilled_model_names:
    modelOptions = cifar10Manager.load_metadata(distilled_model_name, 0)[0]
    # small old bug in the saving of metadata, this is a cheap trick to remedy it
    for key, val in modelOptions.items():
        if isinstance(val, str):
            modelOptions[key] = eval(val)
    for numBit in numBits:
        if numBit == 8: continue
        distilled_quantized_model_name = distilled_model_name + '_quant_points_{}bits'.format(
            numBit)
        distilled_quantized_model = convForwModel.ConvolForwardNet(
            **modelOptions)
        if USE_CUDA:
            distilled_quantized_model = distilled_quantized_model.cuda()
        save_path = cifar10Manager.get_model_base_path(distilled_model_name) + \
                    'quant_points_{}bits'.format(numBit)

        with open(save_path, 'rb') as p:
            quantization_points, infoDict = pickle.load(p)
        distilled_quantized_model.load_state_dict(
            torch.load(save_path + '_model_state_dict'))

        quantization_functions = [
            functools.partial(quantization.nonUniformQuantization,
                              listQuantizationPoints=qp,
                              bucket_size=256) for qp in quantization_points
        ]
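quantization.nonUniformQuantization is project code; a common reading of non-uniform quantization with a list of points is that each weight is snapped to its nearest point, with bucket_size controlling how many consecutive values share one scaling factor. A minimal sketch of the nearest-point step only (per-bucket scaling omitted), under that assumption:

import torch

def nearest_point_quantize(tensor, quant_points):
    # quant_points: 1-D tensor of learned levels, e.g. 2**numBit of them.
    flat = tensor.reshape(-1, 1)                         # (n, 1)
    dist = (flat - quant_points.reshape(1, -1)).abs()    # distance to every level, (n, p)
    nearest = quant_points[dist.argmin(dim=1)]           # snap each value to closest level
    return nearest.reshape(tensor.shape)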