Example #1
def pre_train(epoch, batch_size, learning_rate, weight_decay, momentum):
    # Requires torch, torch.nn as nn, torch.optim as optim,
    # torchvision.transforms as transforms, and the project-local
    # TinyImageNet dataset, VGG model, and train() helper.
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # read data
    # transform lets you apply transform function to each image and create a pipeline of sequence of operations
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        # transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        # transforms.RandomCrop(64, padding=4),
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # calls the TinyImageNet init function that reads data from the training or val directory of dataset
    train_set = TinyImageNet(root='./data', train=True, transform=transform_train,download=False)


    #Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=2)


    # calls the TinyImageNet init function that reads data from the training or val directory of dataset
    test_set = TinyImageNet(root='./data', train=False, transform=transform_test,download=False)

    #Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=50, shuffle=True, num_workers=2)


    print('==> creating model..')
    net = VGG('VGG3').to(device)

    criterion = nn.CrossEntropyLoss()
    # Use keyword arguments here: SGD's third positional parameter is
    # dampening, not weight_decay.
    optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                          momentum=momentum, weight_decay=weight_decay)


    print( "trainset num = {}".format(len(train_set)) )

    print( "trainset num = {}".format(len(test_set)) )

    for epoch in range(0, epoch):
      train(epoch,net,criterion,optimizer,train_loader)
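The loop above delegates to a project-local train() that the snippet does not show. A minimal sketch of such a loop, assuming standard (inputs, targets) batches and the call signature used above:

def train(epoch, net, criterion, optimizer, train_loader):
    # Hypothetical stand-in for the project-local train() helper.
    net.train()
    device = next(net.parameters()).device
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('epoch {}: avg loss {:.4f}'.format(epoch, running_loss / len(train_loader)))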
Example #2
# model.to(gpu_device)

best_prec1 = 0.0  # fallback when no checkpoint is loaded
if pretrain_weight:
    if os.path.isfile(pretrain_weight):
        print("=> loading checkpoint '{}'".format(pretrain_weight))
        checkpoint = torch.load(pretrain_weight)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(pretrain_weight, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

origin_model_acc = best_prec1
ori_model_parameters = sum([param.nelement() for param in model.parameters()])

# origin model calc time
random_input = torch.rand((1, 3, 32, 32)).to(cpu_device)
model.to(cpu_device)
origin_forward_time, origin_flops, origin_params = calc_time_and_flops(random_input, model)
model.to(gpu_device)

#######################
# pre process
#######################
# determine prune mask
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
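The fragment ends after counting BatchNorm channels. A common continuation in network-slimming-style pruning (an assumption here; prune_ratio and the threshold logic are not in the source) gathers every BN scale factor and derives a global magnitude threshold:

bn_weights = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn_weights[index:index + size] = m.weight.data.abs().clone()
        index += size
# Channels whose scale factor falls below the threshold get masked out.
sorted_weights, _ = torch.sort(bn_weights)
threshold = sorted_weights[int(total * prune_ratio)]  # prune_ratio is assumed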
Example #3
parser.add_argument('--a', default=0.4, type=float)
parser.add_argument('--c', default=500, type=int)
parser.add_argument('--num', default=100, type=int)
parser.add_argument('--i', default=1, type=int)
# Increase model diversity
parser.add_argument('--level', type=str, default='filter')
args = parser.parse_args()
print(args)
print('Session:%s\tModel:%d\tPID:%d' % (args.s, args.i, os.getpid()))
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = VGG(deep=16, n_classes=args.cifar).to(args.device)
last_net = VGG(deep=16, n_classes=args.cifar).to(args.device)
start_epoch = 0
accuracy = []
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Data
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
if args.cifar == 10:
    trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
elif args.cifar == 100:
Example #4
def setup_and_prune(cmd_args,
                    params,
                    main_logger,
                    *,
                    prune_type="single",
                    model_name=None):
    """compress a model

    cmd_args: result from argparse
    params: parameters used to build a pruner

    """
    assert prune_type in ["single", "multiple"]
    assert model_name in ["vgg16", "resnet50", "resnet56"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_data_loaders(cmd_args)
    if model_name == "vgg16":
        model = VGG("VGG16").to(device=device)
    elif model_name == "resnet50":
        model = ResNet50().to(device=device)
    elif model_name == "resnet56":
        model = resnet56().to(device=device)
    else:
        raise ValueError(f"Model {model_name} is wrong.")

    main_logger.info("Loading pretrained model {} ...".format(
        cmd_args.pretrained))
    try:
        model.load_state_dict(torch.load(cmd_args.pretrained))
    except FileNotFoundError:
        print("Pretrained model doesn't exist")
        sys.exit(1)  # exit with a nonzero status on failure

    main_logger.info("Testing pretrained model...")
    test(cmd_args, model, device, test_loader, main_logger)

    main_logger.info("start model pruning...")

    if isinstance(model, nn.DataParallel):
        model = model.module

    optimizer_finetune = torch.optim.SGD(model.parameters(),
                                         lr=cmd_args.finetune_lr,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    best_top1 = 0

    if model_name == "vgg16":
        if prune_type == "single":
            prune_stats = do_prune(model, params)
        else:
            prune_stats = do_prune_multiple(model, params)
    elif model_name == "resnet50":
        if prune_type == "multiple":
            prune_stats = do_prune_multiple_resnet50(model, params)
        else:
            raise ValueError(f"prune type {prune_type} is not implemented")
    elif model_name == "resnet56":
        if prune_type == "multiple":
            prune_stats = do_prune_resnet56(model, params)
        else:
            raise ValueError(f"prune type {prune_type} is not implemented")

    if cmd_args.multi_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    main_logger.info("Testing pruned model before finetune...")
    test(cmd_args, model, device, test_loader, main_logger)
    for epoch in range(cmd_args.prune_epochs):
        main_logger.info("# Finetune Epoch {} #".format(epoch + 1))
        train(
            cmd_args,
            model,
            device,
            train_loader,
            optimizer_finetune,
            epoch,
            main_logger,
        )
        main_logger.info("Testing finetuned model after pruning...")
        top1 = test(cmd_args, model, device, test_loader, main_logger)
        if top1 > best_top1:
            best_top1 = top1
            # export finetuned model

    info = {}
    info["top1"] = best_top1 / 100
    info["sparsity"] = prune_stats["sparsity"]
    info["value"] = -(best_top1 / 100 + prune_stats["sparsity"])
    info["value_sigma"] = 0.25
    return info
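The returned dict is shaped for a black-box tuner: value is the negated sum of top-1 accuracy (as a fraction) and sparsity, so minimizing it rewards both. A hypothetical invocation; the params dict shape and logger name are assumptions, not taken from the source:

import logging

info = setup_and_prune(cmd_args,
                       params={'sparsity': 0.5},  # assumed pruner parameter
                       main_logger=logging.getLogger('prune'),
                       prune_type='multiple',
                       model_name='resnet56')
print(info['top1'], info['sparsity'], info['value'])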
Example #5
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Cifar10 LeNet Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1e-5, metavar='LR',
                        help='learning rate (default: 1e-5)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--resume', type=str, default=None, metavar='RESUME',
                        help='Resume model from checkpoint')
    parser.add_argument('--T', type=int, default=60, metavar='N',
                        help='SNN time window')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True},
                     )
    mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
    std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01)
        ])
    im_aug = transforms.Compose([
        #transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding = 6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01)
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
        #transforms.Normalize(mean, std)
        ])

    trainset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    
    # Concatenate 100 augmented copies of CIFAR-10 to enlarge each epoch.
    for _ in range(100):
        trainset = trainset + datasets.CIFAR10(root='./data', train=True, download=True, transform=im_aug)
        
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True)

    testset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False)

    snn_dataset = SpikeDataset(testset, T = args.T)
    snn_loader = torch.utils.data.DataLoader(snn_dataset, batch_size=10, shuffle=False)

    from models.vgg import VGG, CatVGG

    model = VGG('VGG19', clamp_max=1, quantize_bit=32).to(device)
    snn_model = CatVGG('VGG19', args.T).to(device)
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint, strict=False)
        load_model(checkpoint, model)
        load_model(checkpoint, snn_model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        #test(model, device, train_loader)
        test(model, device, test_loader)

        #transfer_model(model, snn_model)
        #test(snn_model, device, snn_loader)
        if args.save_model:
            torch.save(model.state_dict(), "cifar_cnn_19.pt")
        
        scheduler.step()
    #test(model, device, train_loader)
    test(model, device, test_loader)
    transfer_model(model, snn_model)
    with torch.no_grad():
        normalize_weight(snn_model.features, quantize_bit=8)
    test(snn_model, device, snn_loader)
    if args.save_model:
        torch.save(model.state_dict(), "cifar_cnn_19.pt")
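Both this example and Example #9 use a custom AddGaussianNoise transform that neither snippet defines. A minimal sketch of such a transform (an assumption about the real class; it runs after ToTensor, so it operates on tensors):

import torch

class AddGaussianNoise(object):
    # Hypothetical implementation: adds zero-mean Gaussian noise to a tensor image.
    def __init__(self, mean=0.0, std=0.01):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return tensor + torch.randn_like(tensor) * self.std + self.mean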
Example #6
    fuse_model_resnet(model)
else:
    fuse_model_vgg(model)
# model.to(gpu_device)

if pretrain_weight:
    if os.path.isfile(pretrain_weight):
        print("=> loading checkpoint '{}'".format(pretrain_weight))
        checkpoint = torch.load(pretrain_weight)
        #        args.start_epoch = checkpoint['epoch']
        #        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(pretrain_weight))
    else:
        print("=> no checkpoint found at '{}'".format(args.model))

ori_model_parameters = sum([param.nelement() for param in model.parameters()])

# origin model calc time
random_input = torch.rand((1, 3, 32, 32)).to(cpu_device)
model.to(cpu_device)
origin_forward_time, origin_flops, origin_params = calc_time_and_flops(
    random_input, model)
model.to(gpu_device)

acc = test(model)

print("=========== accuracy: {} ==============".format(acc))
print("=========== infer time: {} ==============".format(origin_forward_time))
print("=========== params: {} ==============".format(origin_params))
Example #7
                print("{} is not requires_grad".format(n))

    if Isparallel and torch.cuda.is_available():
        model = torch.nn.DataParallel(model)

    model.to(device)

    # ------------------------------------ step 3/5: define the loss function and optimizer ------------------------------------

    loss_f = nn.CrossEntropyLoss().to(device)
    if args.low_lr:
        # A set gives O(1) membership tests in the filters below.
        primary_params_id = {
            id(p) for n, p in model.named_parameters() if "primary" in n
        }
        primary_params = filter(lambda p: id(p) in primary_params_id,
                                model.parameters())
        base_params = filter(lambda p: id(p) not in primary_params_id,
                             model.parameters())
        optimizer = optim.SGD([{
            'params': primary_params,
            'lr': 0.1 * args.lr
        }, {
            'params': base_params
        }],
                              lr=args.lr,
                              momentum=cfg.momentum,
                              weight_decay=cfg.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=cfg.momentum,
Example #8
class DPP(object):
    def __init__(self, args):
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.lr = args.lr
        self.epochs = args.epochs
        self.save_dir = './' + args.save_dir  # TODO: make configurable
        if not os.path.exists(self.save_dir):
            os.mkdir(self.save_dir)

        if (args.model == 'vgg16'):
            self.model = VGG('VGG16', 0)
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=self.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
        elif (args.model == 'dpp_vgg16'):
            self.model = integrated_kernel(args)
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=self.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)

        #Parallel
        num_params = sum(p.numel() for p in self.model.parameters()
                         if p.requires_grad)
        print('The number of parameters of the model is', num_params)

        if (args.save_load):
            location = args.save_location
            print("locaton", location)
            checkpoint = torch.load(location)
            self.model.load_state_dict(checkpoint['state_dict'])

    def train(self, train_loader, test_loader, graph):
        #Declaration Model
        self.model.train()
        best_prec = 0
        losses = AverageMeter()
        top1 = AverageMeter()
        for epoch in range(self.epochs):
            # self.adjust_learning_rate(epoch)
            for inputs, target in train_loader:
                # `async` is a reserved word in Python 3.7+; the argument is now non_blocking
                target = target.cuda(non_blocking=True)
                input_var = inputs.cuda()
                target_var = target
                output = self.model(input_var)
                loss = self.criterion(output, target_var)
                #Compute gradient and Do SGD step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                #Measure accuracy and record loss
                prec1 = self.accuracy(output.data, target)[0]
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))

            graph.train_loss(losses.avg, epoch, 'train_loss')
            graph.train_acc(top1.avg, epoch, 'train_acc')
            prec = self.test(test_loader, epoch, graph)
            if (prec > best_prec):
                print("Acc", prec)
                best_prec = prec
                self.save_checkpoint(
                    {
                        'best_prec1': best_prec,
                        'state_dict': self.model.state_dict(),
                    },
                    filename=os.path.join(self.save_dir,
                                          'checkpoint_{}.tar'.format(epoch)))

    def test(self, test_loader, epoch, test_graph):
        self.model.eval()
        losses = AverageMeter()
        top1 = AverageMeter()
        with torch.no_grad():  # no gradients needed during evaluation
            for inputs, target in test_loader:
                target = target.cuda()
                inputs = inputs.cuda()
                output = self.model(inputs)
                loss = self.criterion(output, target)
                # Measure accuracy and record loss
                prec1 = self.accuracy(output.data, target)[0]
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
        test_graph.test_loss(losses.avg, epoch, 'test_loss')
        test_graph.test_acc(top1.avg, epoch, 'test_acc')
        return top1.avg

    def accuracy(self, output, target, topk=(1, )):
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def adjust_learning_rate(self, epoch):
        self.lr = self.lr * (0.1**(epoch // 90))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        torch.save(state, filename)
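accuracy() returns top-k precision as a percentage of the batch. A small worked example with invented tensors, using the same ops the method applies:

import torch

output = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.4, 0.3, 0.2],
                       [0.1, 0.2, 2.2]])
target = torch.tensor([0, 1, 2, 2])
_, pred = output.topk(1, 1, True, True)               # predicted classes: 0, 1, 0, 2
correct = pred.t().eq(target.view(1, -1))             # [[True, True, False, True]]
print(correct.reshape(-1).float().sum() * 100.0 / 4)  # tensor(75.)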
Example #9
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch Cifar10 LeNet Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run',
                        action='store_true',
                        default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        metavar='RESUME',
                        help='Resume model from checkpoint')
    parser.add_argument('--T',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='SNN time window')
    parser.add_argument('--k',
                        type=int,
                        default=100,
                        metavar='N',
                        help='Data augmentation')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({
            'num_workers': 1,
            'pin_memory': True,
            'shuffle': True
        }, )

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01)
    ])

    transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = datasets.CIFAR10(root='./data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader_ = torch.utils.data.DataLoader(trainset,
                                                batch_size=512,
                                                shuffle=True)

    for _ in range(args.k):

        im_aug = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomCrop(32, padding=6),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            AddGaussianNoise(std=0.01)
        ])
        trainset = trainset + datasets.CIFAR10(
            root='./data', train=True, download=True, transform=im_aug)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=256 + 512,
                                               shuffle=True)

    testset = datasets.CIFAR10(root='./data',
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=100,
                                              shuffle=False)

    snn_dataset = SpikeDataset(testset, T=args.T)
    snn_loader = torch.utils.data.DataLoader(snn_dataset,
                                             batch_size=10,
                                             shuffle=False)

    from models.vgg import VGG, VGG_, CatVGG, CatVGG_

    model = VGG('VGG16', clamp_max=1, quantize_bit=32, bias=False).to(device)
    snn_model = CatVGG('VGG16', args.T, bias=True).to(device)

    #Trainable pooling
    #model = VGG_('VGG19_', clamp_max=1, quantize_bit=32,bias =True).to(device)
    #snn_model = CatVGG_('VGG19_', args.T,bias =True).to(device)

    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume), strict=False)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    correct_ = 0
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, train_loader_)
        correct = test(model, device, test_loader)
        if correct > correct_:
            correct_ = correct
        scheduler.step()

    model = fuse_bn_recursively(model)
    transfer_model(model, snn_model)
    with torch.no_grad():
        normalize_weight(snn_model.features, quantize_bit=32)
    test(snn_model, device, snn_loader)
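SpikeDataset is project-local and not shown. For ANN-to-SNN evaluation it typically rate-codes each image into T binary frames; a plausible minimal wrapper (purely an assumption about the real class):

import torch

class SpikeDataset(torch.utils.data.Dataset):
    # Hypothetical rate coding: pixel intensity becomes a Bernoulli spike train.
    def __init__(self, dataset, T):
        self.dataset = dataset
        self.T = T

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]  # image in [0, 1] after ToTensor
        spikes = torch.rand(self.T, *image.shape) < image  # (T, C, H, W)
        return spikes.float(), label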
Example #10
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # setup device for training
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Setup best accuracy for comparing and model checkpoints
    best_accuracy = 0.0

    # setup Tensorboard file path
    writer = SummaryWriter('experiments/students/vgg/vgg11')

    # Configure the network
    # You can swap in any architecture from /models here
    # Student model is the VGG11 architecture
    model_fn = VGG('VGG11')
    model_fn = model_fn.to(device)
    cudnn.benchmark = True
    
    summary(model_fn, (3, 32, 32))

    # Set up the optimizer for all parameters
    optimizer_fn = optim.SGD(model_fn.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    scheduler = StepLR(optimizer_fn, step_size=50, gamma=0.1)

    train_and_evaluate(model=model_fn, train_dataloader=trainloader, test_dataloader=testloader,
                        optimizer=optimizer_fn, scheduler=scheduler, total_epochs=200, temperature=temperature, alpha=alpha)

    writer.close()
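train_and_evaluate takes temperature and alpha, the usual knobs of knowledge distillation. A sketch of the standard KD loss it presumably applies internally (not the repo's actual code):

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, temperature, alpha):
    # Soft term: KL divergence between temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable.
    soft = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                    F.softmax(teacher_logits / temperature, dim=1),
                    reduction='batchmean') * temperature * temperature
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard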
Example #11

########################
# network
########################
if args.arch == "vgg":
    model = VGG(depth=args.depth, slim_channel=args.slim_channel)
else:
    # ResNet doesn't support the slim-channel strategy
    model = ResNet(depth=args.depth)

if torch.cuda.is_available():
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=momentum, weight_decay=weight_decay)

# load pre weight
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
Example #12
    return float(correct)


if opt.data == 'cifar10':
    n_class = 10
    epochs = [80, 60, 40, 20]
elif opt.data == 'restricted_imagenet':
    n_class = 9  # assumption: the usual restricted-ImageNet class grouping
    epochs = [30, 20, 20, 10]
elif opt.data == 'mnist':
    epochs = [60, 40, 20]
    n_class = 10
count = 0

u = torch.ones(n_class).cuda()
u.requires_grad = True
best_acc = 0
for n_epochs in epochs:
    optimizer_net = SGD(net.parameters(),
                        lr=opt.lr,
                        momentum=0.9,
                        weight_decay=5.0e-4)
    #optimizer_u = SGD([u], lr=opt.lr*0.2, momentum=0.9, weight_decay=5.0e-4)
    #optimizer_u = SGD([u], lr=opt.lr*5, momentum=0.9, weight_decay=5.0e-4)
    optimizer_u = SGD([u], lr=opt.lr, momentum=0.9, weight_decay=5.0e-4)
    for _ in range(n_epochs):
        train(count, u)
        test(count)
        #test_attack()
        count += 1
    opt.lr /= 10
Example #13
                net = torch.nn.DataParallel(net)
                cudnn.benchmark = True

            if args.resume:
                # Load checkpoint.
                print('==> Resuming from checkpoint..')
                assert os.path.isdir(
                    'checkpoint'), 'Error: no checkpoint directory found!'
                checkpoint = torch.load('./checkpoint/' + exp_name + '.t7')
                net.load_state_dict(checkpoint['net'])
                best_acc = checkpoint['acc']
                start_epoch = checkpoint['epoch']

            criterion = nn.CrossEntropyLoss()
            #optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
            optimizer = optim.Adam(net.parameters())
            #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

            train_loss = []
            train_acc = []

            val_loss = []
            val_acc = []

            for epoch in range(start_epoch, start_epoch + 200):
                #scheduler.step()
                l, a = train(epoch, progress_bar)
                _, _, tl, ta = test(epoch, progress_bar, exp_name)
                train_loss.append(l)
                train_acc.append(a)
                val_loss.append(tl)
                val_acc.append(ta)