コード例 #1
0
def test(i, key, shape, rand=False, randFactor=256):
    """Flip the sign of one binarized weight and return the test accuracy.

    A fresh ``nin.Net`` is loaded from ``args.pretrained``, its weights are
    binarized through ``util.BinOp``, the element at flat (row-major) index
    ``i`` of ``state_dict[key]`` is multiplied by -1 in place, and the model
    is evaluated on ``testloader``.

    Args:
        i: Flat index of the element to flip inside ``state_dict[key]``.
        key: Name of the tensor in the model's ``state_dict``.
        shape: Shape of that tensor. Only its length (1, 2 or 4) and, for
            4-D tensors, ``shape[1:]`` are used. For any other length no
            element is flipped and the unmodified model is evaluated.
        rand: If true and the tensor is 4-D, the element is evaluated only
            when its dim-1 index equals a freshly drawn random index;
            otherwise the call returns 100 without loading or evaluating
            anything.
        randFactor: Unused; kept so existing callers keep working.

    Returns:
        The accuracy in percent over ``testloader.dataset`` (float), or the
        int ``100`` when the element was skipped by random sampling.

    Side effects:
        Sets the global ``best_acc`` from the checkpoint, except on the
        random-sampling 4-D path (which never set it in the original code).
    """
    global best_acc

    is_4d = len(shape) == 4
    sampled = rand and is_4d

    if is_4d:
        size1, size2, size3 = shape[1], shape[2], shape[3]
        # Unravel the flat index with exact integer arithmetic.
        rest, idx3 = divmod(i, size3)
        rest, idx2 = divmod(rest, size2)
        idx0, idx1 = divmod(rest, size1)
        if sampled:
            # torch.randint's upper bound is exclusive. The previous bound of
            # ``size1 - 1`` could never pick the last dim-1 index and raised
            # for size1 == 1, so draw from the full range [0, size1).
            if idx1 != int(torch.randint(0, size1, [1])):
                return 100

    # Load and binarize a pristine copy of the pretrained network.
    model = nin.Net()
    pretrained_model = torch.load(args.pretrained)
    if not sampled:
        best_acc = pretrained_model['best_acc']
    model.load_state_dict(pretrained_model['state_dict'])
    model.to(device)
    bin_op = util.BinOp(model)
    model.eval()
    bin_op.binarization()
    state_dict = model.state_dict()

    # Flip the selected element in place.
    if is_4d:
        state_dict[key][idx0][idx1][idx2][idx3].mul_(-1)
    elif len(shape) == 1:
        state_dict[key][i].mul_(-1)
    elif len(shape) == 2:
        size = state_dict[key].shape[1]
        row, col = divmod(i, size)
        state_dict[key][row][col].mul_(-1)

    correct = 0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Index of the highest score per sample.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    bin_op.restore()
    return 100. * correct / len(testloader.dataset)
コード例 #2
0
ファイル: main.py プロジェクト: yawudede/micronet
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers)  # 测试集数据

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'],
                            a_bits=args.a_bits,
                            w_bits=args.w_bits)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'],
                               a_bits=args.a_bits,
                               w_bits=args.w_bits)
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = 0
    else:
        print('******Initializing model******')
        if args.model_type == 0:
            model = nin.Net(a_bits=args.a_bits, w_bits=args.w_bits)
        else:
            model = nin_gc.Net(a_bits=args.a_bits, w_bits=args.w_bits)
        best_acc = 0
        for m in model.modules():
コード例 #3
0
            transforms.Normalize(
                (0.491399689874, 0.482158419622, 0.446530924224),
                (0.247032237587, 0.243485133253, 0.261587846975))
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             **kwargs)

    # define classes
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # define the model
    print('==> building model', args.arch, '...')
    if args.arch == 'nin':
        model = nin.Net()
    else:
        raise Exception(args.arch + ' is currently not supported')

    # initialize the model
    if not args.pretrained:
        print('==> Initializing model parameters ...')
        best_acc = 0
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.05)
                m.bias.data.zero_()
    else:
        print('==> Load pretrained model form', args.pretrained, '...')
        pretrained_model = torch.load(args.pretrained)
        best_acc = pretrained_model['best_acc']
コード例 #4
0
ファイル: bn_fuse.py プロジェクト: VincentSeven1/micronet
                        default=1,
                        help='model type:0-nin,1-nin_gc')
    parser.add_argument('--W', type=int, default=2, help='Wb:2, Wt:3, Wfp:32')
    parser.add_argument('--A', type=int, default=2, help='Ab:2, Afp:32')

    args = parser.parse_args()
    print('==> Options:', args)

    if args.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    if args.prune_quant:
        print('******Prune Quant model******')
        if args.model_type == 0:
            checkpoint = torch.load('../models_save/nin.pth')
            quant_model_train = nin.Net(cfg=checkpoint['cfg'])
        else:
            checkpoint = torch.load('../models_save/nin_gc.pth')
            quant_model_train = nin_gc.Net(cfg=checkpoint['cfg'])
    else:
        if args.model_type == 0:
            checkpoint = torch.load('../models_save/nin.pth')
            quant_model_train = nin.Net()
        else:
            checkpoint = torch.load('../models_save/nin_gc.pth')
            quant_model_train = nin_gc.Net()
    quant_bn_fused_model_inference = copy.deepcopy(quant_model_train)
    quantize.prepare(quant_model_train, inplace=True, A=args.A, W=args.W)
    quantize.prepare(quant_bn_fused_model_inference,
                     inplace=True,
                     A=args.A,
コード例 #5
0
    testset = torchvision.datasets.CIFAR10(root=args.data, train=False, download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size,
                                             shuffle=False, num_workers=args.num_workers)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    if args.prune_quant:
        print('******Prune Quant model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.prune_quant)
        cfg = checkpoint['cfg']
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'])
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'])
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = 0
        print('***ori_model***\n', model)
        quantize.prepare(model, inplace=True, a_bits=args.a_bits,
                         w_bits=args.w_bits, q_type=args.q_type,
                         q_level=args.q_level, device=device,
                         weight_observer=args.weight_observer,
                         bn_fuse=args.bn_fuse,
                         bn_fuse_cali=args.bn_fuse_cali,
                         pretrained_model=args.pretrained_model,
                         qaft=args.qaft,
                         ptq=args.ptq,
                         percentile=args.percentile)
コード例 #6
0
                    help='path to save prune model (default: none)')
# 后续量化类型选择(三/二值、高位)
parser.add_argument('--quant_type',
                    type=int,
                    default=0,
                    help='quant_type:0-tnn_bin_model, 1-quant_model')
args = parser.parse_args()
base_number = args.normal_regular
layers = args.layers
print(args)

if base_number <= 0:
    print('\r\n!base_number is error!\r\n')
    base_number = 1

# Build the network, then optionally restore its weights from the checkpoint
# file passed via args.model.
model = nin.Net(quant_type=args.quant_type)
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # Only the 'state_dict' entry of the loaded checkpoint is used here.
        model.load_state_dict(torch.load(args.model)['state_dict'])
    else:
        # Report the path that was actually checked (args.model); the message
        # previously formatted args.resume, which is not the option tested.
        print("=> no checkpoint found at '{}'".format(args.model))
print('旧模型: ', model)
total = 0
i = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            total += m.weight.data.shape[0]
コード例 #7
0
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    trainset = torchvision.datasets.CIFAR10(root = args.data, train = True, download = True, transform = transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.num_workers) # 训练集数据

    testset = torchvision.datasets.CIFAR10(root = args.data, train = False, download = True, transform = transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers) # 测试集数据

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'], abits=args.Abits, wbits=args.Wbits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'], abits=args.Abits, wbits=args.Wbits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        model_dict = model.state_dict()
        update_state_dict = {k:v for k,v in checkpoint['state_dict'].items() if k in model_dict.keys()}  
        model_dict.update(update_state_dict)
        print('fp32_model weight load successfully')
        model.load_state_dict(model_dict)
        best_acc = 0
    else:
        print('******Initializing model******')
        if args.model_type == 0:
            model = nin.Net(abits=args.Abits, wbits=args.Wbits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        else:
            model = nin_gc.Net(abits=args.Abits, wbits=args.Wbits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        best_acc = 0
コード例 #8
0
ファイル: main.py プロジェクト: yawudede/micronet
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers)  # 测试集数据

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'],
                            a_bits=args.a_bits,
                            w_bits=args.w_bits,
                            bn_fuse=args.bn_fuse,
                            q_type=args.q_type,
                            q_level=args.q_level,
                            device=device,
                            weight_observer=args.weight_observer)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'],
                               a_bits=args.a_bits,
                               w_bits=args.w_bits,
                               bn_fuse=args.bn_fuse,
                               q_type=args.q_type,
                               q_level=args.q_level,
                               device=device,
                               weight_observer=args.weight_observer)
        model_dict = model.state_dict()
        update_state_dict = {
            k: v
コード例 #9
0
        testset,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers)  # 测试集数据

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'],
                            abits=args.Abits,
                            wbits=args.Wbits,
                            bn_fold=args.bn_fold,
                            q_type=args.q_type)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'],
                               abits=args.Abits,
                               wbits=args.Wbits,
                               bn_fold=args.bn_fold,
                               q_type=args.q_type)
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = 0
    else:
        print('******Initializing model******')
        if args.model_type == 0:
            model = nin.Net(abits=args.Abits,
                            wbits=args.Wbits,
コード例 #10
0
        testset,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers)  # 测试集数据

    # cifar10类别
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # model
    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'], A=args.A, W=args.W)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'], A=args.A, W=args.W)
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = 0
    else:
        print('******Initializing model******')
        # ******************** 在model的量化卷积中同时量化A(特征)和W(模型参数) ************************
        if args.model_type == 0:
            model = nin.Net(A=args.A, W=args.W)
        else:
            model = nin_gc.Net(A=args.A, W=args.W)
        #model = nin_bn_conv.Net()
        best_acc = 0
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
コード例 #11
0
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    trainset = torchvision.datasets.CIFAR10(root = args.data, train = True, download = True, transform = transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.num_workers) # 训练集数据

    testset = torchvision.datasets.CIFAR10(root = args.data, train = False, download = True, transform = transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers) # 测试集数据

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'], a_bits=args.a_bits, w_bits=args.w_bits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'], a_bits=args.a_bits, w_bits=args.w_bits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        model_dict = model.state_dict()
        update_state_dict = {k:v for k,v in checkpoint['state_dict'].items() if k in model_dict.keys()}  
        model_dict.update(update_state_dict)
        print('fp32_model weight load successfully')
        model.load_state_dict(model_dict)
        best_acc = 0
    else:
        print('******Initializing model******')
        if args.model_type == 0:
            model = nin.Net(a_bits=args.a_bits, w_bits=args.w_bits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        else:
            model = nin_gc.Net(a_bits=args.a_bits, w_bits=args.w_bits, bn_fuse=args.bn_fuse, q_type=args.q_type, q_level=args.q_level)
        best_acc = 0
コード例 #12
0
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=2)

    # define classes
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('models_save/nin_prune.pth')
        checkpoint = torch.load(args.refine)
        cfg = checkpoint['cfg']
        model = nin.Net(cfg=checkpoint['cfg'], quant_type=args.quant_type)
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = 0
    else:
        # nin_gc_retrain
        if args.gc_refine:
            print('******Refine model******')
            cfg = args.gc_refine
            model = nin_gc.Net(cfg=cfg, quant_type=args.quant_type)
        else:
            print('******Initializing model******')
            if args.model_type == 0:
                model = nin.Net(quant_type=args.quant_type)
            else:
                model = nin_gc.Net(quant_type=args.quant_type)
コード例 #13
0
    trainset = data.dataset(root=args.data, train=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
            shuffle=True, num_workers=6)

    testset = data.dataset(root=args.data, train=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128,
            shuffle=False, num_workers=6)

    # define classes
    classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # define the model
    print('==> building model',args.arch,'...')
    if args.arch == 'nin':
        student = nin.Net()
        student.cuda()

    if args.netD == 'basic':
        netD = discriminator.netD()
        netD.cuda()

    if not args.pretrainedstudent:
        print('==> Initializing model parameters ...')
        best_acc = 0
        for m in student.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.05)
                m.bias.data.zero_()
        for d in netD.modules():
            if isinstance(m, nn.Conv2d):
コード例 #14
0
# 剪枝后保存的model
parser.add_argument('--save',
                    default='models_save/nin_prune.pth',
                    type=str,
                    metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
base_number = args.normal_regular
layers = args.layers
print(args)

if base_number <= 0:
    print('\r\n!base_number is error!\r\n')
    base_number = 1

# Build the network, then optionally restore its weights from the checkpoint
# file passed via args.model.
model = nin.Net()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # Only the 'state_dict' entry of the loaded checkpoint is used here.
        model.load_state_dict(torch.load(args.model)['state_dict'])
    else:
        # Report the path that was actually checked (args.model); the message
        # previously formatted args.resume, which is not the option tested.
        print("=> no checkpoint found at '{}'".format(args.model))
print('旧模型: ', model)
total = 0
i = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            total += m.weight.data.shape[0]
コード例 #15
0
def load_model(opt):
    """Return the model selected by ``opt``.

    If ``opt.pretrained_file`` is a non-empty string, the object stored in
    that file is loaded with ``torch.load`` and returned as-is (it is not
    moved to the GPU here). Otherwise a new network is built according to
    ``opt.model_def`` and moved to the GPU when ``opt.cuda`` is truthy.

    Args:
        opt: Options object providing ``pretrained_file``, ``model_def``,
            ``cuda`` and, for most architectures, ``nClasses``.

    Returns:
        The loaded or newly constructed model.

    Raises:
        ValueError: If ``opt.model_def`` names no known architecture.
            (Previously this fell through to ``return model`` with ``model``
            unbound and surfaced as an UnboundLocalError.)
    """
    if opt.pretrained_file != "":
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        return torch.load(opt.pretrained_file)

    # Each entry is a zero-argument factory so that the model modules and
    # opt.nClasses are only touched for the architecture actually requested.
    builders = {
        'alexnet': lambda: alexnet.Net(opt.nClasses),
        'bincifar': lambda: bincifar.Net(opt.nClasses),
        'bincifarfbin': lambda: bincifarfbin.Net(opt.nClasses),
        'densenet': lambda: densenet.DenseNet3(32, 10),
        'alexnetfbin': lambda: alexnetfbin.Net(opt.nClasses),
        'alexnethybrid': lambda: alexnethybrid.Net(opt.nClasses),
        'alexnethybridv2': lambda: alexnethybridv2.Net(opt.nClasses),
        'alexnetwbin': lambda: alexnetwbin.Net(opt.nClasses),
        'googlenet': lambda: googlenet.Net(opt.nClasses),
        'googlenetfbin': lambda: googlenetfbin.Net(opt.nClasses),
        'googlenetwbin': lambda: googlenetwbin.Net(opt.nClasses),
        'mobilenet': lambda: mobilenet.Net(opt.nClasses),
        'nin': lambda: nin.Net(),
        'resnet18': lambda: resnet.resnet18(opt.nClasses),
        'resnetfbin18': lambda: resnetfbin.resnet18(opt.nClasses),
        'resnethybrid18': lambda: resnethybrid.resnet18(opt.nClasses),
        'resnethybridv218': lambda: resnethybridv2.resnet18(opt.nClasses),
        'resnethybridv318': lambda: resnethybridv3.resnet18(opt.nClasses),
        'resnetwbin18': lambda: resnetwbin.resnet18(opt.nClasses),
        'sketchanet': lambda: sketchanet.Net(opt.nClasses),
        'sketchanetfbin': lambda: sketchanetfbin.Net(opt.nClasses),
        'sketchanethybrid': lambda: sketchanethybrid.Net(opt.nClasses),
        'sketchanethybridv2': lambda: sketchanethybridv2.Net(opt.nClasses),
        'sketchanetwbin': lambda: sketchanetwbin.Net(opt.nClasses),
        'squeezenet': lambda: squeezenet.Net(opt.nClasses),
        'squeezenetfbin': lambda: squeezenetfbin.Net(opt.nClasses),
        'squeezenethybrid': lambda: squeezenethybrid.Net(opt.nClasses),
        'squeezenethybridv2': lambda: squeezenethybridv2.Net(opt.nClasses),
        'squeezenethybridv3': lambda: squeezenethybridv3.Net(opt.nClasses),
        'squeezenetwbin': lambda: squeezenetwbin.Net(opt.nClasses),
        'vgg16_bncifar': lambda: vgg.vgg16_bn(),
    }

    build = builders.get(opt.model_def)
    if build is None:
        raise ValueError(
            "unsupported model_def '{}'; expected one of: {}".format(
                opt.model_def, ', '.join(sorted(builders))))

    model = build()
    if opt.cuda:
        model = model.cuda()
    return model
コード例 #16
0
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    trainset = torchvision.datasets.CIFAR10(root = args.data, train = True, download = True, transform = transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.num_workers) # 训练集数据

    testset = torchvision.datasets.CIFAR10(root = args.data, train = False, download = True, transform = transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.num_workers) # 测试集数据

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    if args.refine:
        print('******Refine model******')
        #checkpoint = torch.load('../prune/models_save/nin_refine.pth')
        checkpoint = torch.load(args.refine)
        if args.model_type == 0:
            model = nin.Net(cfg=checkpoint['cfg'], abits=args.Abits, wbits=args.Wbits)
        else:
            model = nin_gc.Net(cfg=checkpoint['cfg'], abits=args.Abits, wbits=args.Wbits)
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = 0
    else:
        print('******Initializing model******')
        if args.model_type == 0:
            model = nin.Net(abits=args.Abits, wbits=args.Wbits)
        else:
            model = nin_gc.Net(abits=args.Abits, wbits=args.Wbits)
        best_acc = 0
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data)
                m.bias.data.zero_()
コード例 #17
0
parser.add_argument('--save',
                    default='models_save/nin_prune.pth',
                    type=str,
                    metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
base_number = args.normal_regular
layers = args.layers
print(args)

if base_number <= 0:
    print('\r\n!base_number is error!\r\n')
    base_number = 1

# Define the model and load its parameters from the checkpoint file passed
# via args.model (if any).
model = nin.Net()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # Only the 'state_dict' entry of the loaded checkpoint is used here.
        model.load_state_dict(torch.load(args.model)['state_dict'])
    else:
        # Report the path that was actually checked (args.model); the message
        # previously formatted args.resume, which is not the option tested.
        print("=> no checkpoint found at '{}'".format(args.model))
print('旧模型: ', model)
# ===================================================================

total = 0  # 所有BN层的channel之和
i = 0  # i 为 batchnorm 的统计层数
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1