Пример #1
0
def main():
    """Prune a pretrained ResNet-32 on CIFAR-10 and fine-tune the result.

    Pipeline (progress is printed to stdout):
      1. Attach 1x1 convolutions to the prunable layers and briefly
         fine-tune the augmented model.
      2. Prune output then input channels layer by layer, from the last
         layer down to the first.
      3. Run a final fine-tuning schedule on the pruned network.

    Relies on module-level helpers visible elsewhere in this file
    (``parser``, ``validate``, ``fine_tuning``, ``pruning_output_channel``,
    ``pruning_input_channel``, ``get_params``, ``get_flops``,
    ``get_layer_names``, ``add_1x1_convs``, ``expected_flops``,
    ``adjust_learning_rate``) and the project ``resnet`` module.
    """
    # ``args`` is shared with the helper functions through module scope.
    # The original also declared ``iters`` and ``file`` as globals, but
    # neither was ever assigned or read here, so both are dropped.
    global args
    args = parser.parse_args()

    # "--gpu 0,1" -> [0, 1]; the first id becomes the master device.
    args.gpu = [int(i) for i in args.gpu.split(',')]
    torch.cuda.set_device(args.gpu[0] if args.gpu else None)
    torch.backends.cudnn.benchmark = True
    L_cls_f = nn.CrossEntropyLoss().cuda()

    ## Dataset Loading
    # NOTE(review): these are the ImageNet normalization constants rather
    # than CIFAR-10's own statistics — kept as-is to match the checkpoint.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(
            root='./cifar10',
            train=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(), normalize
            ]),
            download=True),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(
            root='./cifar10',
            train=False,
            transform=transforms.Compose([transforms.ToTensor(), normalize])),
        batch_size=128,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    ## Model Initialize and Loading
    model = resnet.resnet32()
    args.checkpoint = 'models/model_32.th'  # hard-coded pretrained weights

    model = nn.DataParallel(model, device_ids=args.gpu).cuda()
    checkpoint = torch.load(args.checkpoint, map_location='cuda:0')
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    # Untouched copy kept as the reference model for fine-tuning/pruning.
    original_model = copy.deepcopy(model)

    loss, init_acc = validate(val_loader, model, L_cls_f, '')
    print('\nOriginal performance. Acc: {:2.2f}%'.format(init_acc))

    # Baselines used to report relative parameter / FLOP percentages.
    num_params = get_params(model)
    num_flops = get_flops(model)

    ## 1. Initialization process
    # layer_names[0] is excluded: the first layer is left untouched.
    layer_names = get_layer_names(model)
    expected_flops(copy.deepcopy(model), layer_names[1:], num_params,
                   num_flops, args.ratio)
    add_1x1_convs(model, layer_names[1:])

    print('== 1. Initialization fine-tuning stage. ')
    model_opt = torch.optim.SGD(model.parameters(),
                                lr=1e-3,
                                momentum=0.9,
                                weight_decay=1e-4)
    for epoch in range(20):
        fine_tuning(model, original_model, train_loader, val_loader, L_cls_f,
                    model_opt, False)
        loss, acc = validate(val_loader, model, L_cls_f, '* ')

        print("[Init {:02d}] Loss: {:.3f}. Acc: {:2.2f}%. || Param: {:2.2f}%  Flop: {:2.2f}%".format(
            epoch + 1, loss, acc,
            get_params(model) / num_params * 100,
            get_flops(model) / num_flops * 100))

    ## 2. Pruning process, from top to bottom
    print('\n== 2. Pruning stage. ')
    for i in range(1, len(layer_names)):
        index = len(layer_names) - i  # walk from the last layer down to 1
        model = pruning_output_channel(model, original_model,
                                       layer_names[index], train_loader,
                                       val_loader, L_cls_f)
        model = pruning_input_channel(model, original_model,
                                      layer_names[index], train_loader,
                                      val_loader, L_cls_f)

        loss, acc = validate(val_loader, model, L_cls_f, '* ')
        print("[Pruning {:02d}]. Loss: {:.3f}. Acc: {:2.2f}%. || Param: {:2.2f}%  Flop: {:2.2f}%".format(
            index, loss, acc,
            get_params(model) / num_params * 100,
            get_flops(model) / num_flops * 100))

    ## 3. Final Fine-tuning stage
    print('\n==3. Final fine-tuning stage after pruning.')
    best_acc = 0
    model_opt = torch.optim.SGD(model.parameters(),
                                lr=1e-2,
                                momentum=0.9,
                                weight_decay=1e-4)

    for epoch in range(args.step_ft):
        adjust_learning_rate(model_opt, epoch, args.step_ft)
        fine_tuning(model, original_model, train_loader, val_loader, L_cls_f,
                    model_opt)
        loss, acc = validate(val_loader, model, L_cls_f, '* ')
        best_acc = max(best_acc, acc)

        print("[Fine-tune {:03d}] Loss: {:.3f}. Acc: {:2.2f}%. || Param: {:2.2f}%  Flop: {:2.2f}%  Best: {:2.2f}%".format(
            epoch + 1, loss, acc,
            get_params(model) / num_params * 100,
            get_flops(model) / num_flops * 100, best_acc))

    print("\n[Final] Baseline: {:2.2f}%. After Pruning: {:2.2f}%. || Diff: {:2.2f}%  Param: {:2.2f}%  Flop: {:2.2f}%".format(
        init_acc, best_acc, init_acc - best_acc,
        get_params(model) / num_params * 100,
        get_flops(model) / num_flops * 100))
Пример #2
0
from model import SequentialImageNetwork, SequentialImageNetworkMod
from util import *
from datasets import *
import re

import sys

# Experiment name from the command line; its dash-separated fields encode
# the run configuration, and the trailing digits encode eps.
name = sys.argv[1]

fields = name.split("-")
model_flag = fields[0]
train_flag = fields[1]

# Build the network selected by the first field; each branch binds its own
# ``resnet`` module at module scope, exactly as before.
if model_flag == "r32p":
    import resnet

    model = SequentialImageNetworkMod(resnet.resnet32()).cuda()
elif model_flag == "r18":
    from pytorch_cifar.models import resnet

    model = SequentialImageNetwork(resnet.ResNet18()).cuda()
else:
    raise NotImplementedError

# Trailing run of digits at the end of the name.
eps = int(re.search(r"[0-9]+$", name).group())
poisoner_flag = fields[3][:3]
# Third field packs two single-digit labels: clean label then target label.
clean_label = int(fields[2][0])
target_label = int(fields[2][1])

print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=} {eps=}")

if len(sys.argv) > 2:
Пример #3
0
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision
from torch.autograd import Variable
from resnet import resnet32
from collections import OrderedDict
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from custom_loader import val_set
# %%
# Device for the model (CUDA assumed available).
device = "cuda"
model = resnet32().to(device)

# Loader over ``val_set`` from custom_loader; shuffled, 4 worker processes,
# pinned host memory for faster host-to-GPU copies.
train_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# %%
# Checkpoint location for the adversarially trained ResNet-32.
saved_model_path = '/root/Adversarial-attacks-DNN-18786/saved_model/resnet32-adv'
criterion = nn.CrossEntropyLoss().cuda()

# SGD with momentum and weight decay; lr=0.1 was passed positionally in the
# original call and is spelled out here for clarity.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)