Esempio n. 1
0
def main2(args):
    """Train (or evaluate) a WideResNet / LRNet variant on CIFAR, CINIC-10 or
    converted ImageNet, optionally starting from a pruning mask and freezing
    parts of the network.

    Args:
        args: Parsed command-line namespace. Mutated in place: ``name`` is
            overwritten when ``init`` is not the string ``"None"``,
            ``classes`` is set from the dataset name and, on resume,
            ``start_epoch`` is taken from the checkpoint.

    Returns:
        The best validation precision seen; the current epoch index when the
        ``symmetry_break`` early stop fires; or ``None`` when setup aborts
        (failed pruning, unsupported arch or optimizer, missing checkpoint).

    Raises:
        ValueError: if the dataset family (``args.dataset`` without its
            digits) is not ``cifar``, ``cinic`` or ``imgnet``.
    """
    # CINIC-10 per-channel statistics on the 0-1 scale; the CINIC loaders
    # below normalize with these directly.
    cinic_mean = [0.47889522, 0.47227842, 0.43047404]
    cinic_std = [0.24205776, 0.23828046, 0.25874835]

    best_prec1 = 0.0

    torch.backends.cudnn.deterministic = not args.cudaNoise

    # Seeded from the wall clock, so runs are not bit-reproducible.
    torch.manual_seed(time.time())

    if args.init != "None":
        args.name = "lrnet_%s" % args.init

    if args.tensorboard:
        configure(f"runs/{args.name}")

    dstype = nondigits(args.dataset)
    # Per-channel pixel statistics on the 0-255 scale. Every supported family
    # must set them here because the shared `normalize` transform needs them.
    if dstype == "cifar":
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    elif dstype == "imgnet":
        means = [123.3, 118.1, 108.0]
        stds = [54.1, 52.6, 53.2]
    elif dstype == "cinic":
        means = [x * 255.0 for x in cinic_mean]
        stds = [x * 255.0 for x in cinic_std]
    else:
        raise ValueError("Error matching dataset %s" % dstype)

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in means],
        std=[x / 255.0 for x in stds],
    )

    args.classes = onlydigits(args.dataset)

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            # Reflect-pad 4 pixels on every side before the random 32x32 crop.
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}

    if dstype == "cifar":
        dataset_cls = datasets.__dict__[args.dataset.upper()]
        train_loader = torch.utils.data.DataLoader(
            dataset_cls("../data",
                        train=True,
                        download=True,
                        transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            dataset_cls("../data", train=False, transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "cinic":
        cinic_directory = "%s/cinic10" % args.dir
        # NOTE(review): the CINIC loaders ignore --augment / --cutout and use
        # their own plain ToTensor+Normalize pipeline; confirm this is intended.
        cinic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std),
        ])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/train',
                                             transform=cinic_transform),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        print("Using CINIC10 dataset")
        val_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/valid',
                                             transform=cinic_transform),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    else:  # dstype == "imgnet" (validated above)
        print("Using converted imagenet")
        train_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=True,
                   transform=transform_train,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=False,
                   transform=transform_test,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )

    # Pruning mask computed from a previously trained network (see
    # getPruneMask); None when pruning is disabled.
    pruner_state = None
    if args.prune:
        pruner_state = getPruneMask(args)
        if pruner_state is None:
            print("Failed to prune network, aborting")
            return None

    # Constructor options shared by both supported architectures.
    net_kwargs = dict(
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
        varnet=args.varnet,
        noise=args.noise,
        lrelu=args.lrelu,
        sigmaW=args.sigmaW,
        init=args.init,
    )
    arch = args.arch.lower()
    if arch == "constnet":
        model = WideResNet(args.layers, args.classes, args.widen_factor,
                           dropl1=args.dropl1, **net_kwargs)
    elif arch == "leakynet":
        model = LRNet(args.layers, args.classes, args.widen_factor,
                      **net_kwargs)
    else:
        print("arch %s is not supported" % args.arch)
        return None

    param_num = sum(p.data.nelement() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        # NOTE(review): assumes args.device looks like "0-3" (single-digit
        # first and last device ids at positions 0 and 2) — confirm.
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = ["cuda:%d" % i for i in range(start, end)]
        model = torch.nn.DataParallel(model, device_ids=dev_list)

    model = model.cuda()

    if args.freeze > 0:
        # Count 'scale' parameters as block boundaries; freeze everything
        # except conv_res/fc from block `freeze_start` until block `freeze`.
        cnt = 0
        for name, param in model.named_parameters():
            parts = name.split('.')
            if intersection(['scale'], parts):
                cnt = cnt + 1
                if cnt == args.freeze:
                    break

            if cnt >= args.freeze_start:
                if not intersection(['conv_res', 'fc'], parts):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    elif args.freeze < 0:
        # Negative freeze counts blocks from the end of the network.
        cnt = 0
        for name, param in model.named_parameters():
            parts = name.split('.')
            if intersection(['scale'], parts):
                cnt = cnt + 1

            if cnt > args.layers - 3 + args.freeze - 1:
                if not intersection(['conv_res', 'fc'], parts):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.res_freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > args.res_freeze_start:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
                if cnt >= args.res_freeze:
                    break
    elif args.res_freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > 3 + args.res_freeze:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.prune:
        # Start the pruned network from the source network's weights.
        if args.prune_epoch >= 100:
            weightsFile = "runs/%s-net/checkpoint.pth.tar" % args.prune
        else:
            weightsFile = "runs/%s-net/model_epoch_%d.pth.tar" % (
                args.prune, args.prune_epoch)

        if os.path.isfile(weightsFile):
            print(f"=> loading checkpoint {weightsFile}")
            checkpoint = torch.load(weightsFile)
            model.load_state_dict(checkpoint["state_dict"])
            print(
                f"=> loaded checkpoint '{weightsFile}' (epoch {checkpoint['epoch']})"
            )
        elif args.prune_epoch == 0:
            print(f"=> No source data, Restarting network from scratch")
        else:
            print(f"=> no checkpoint found at {weightsFile}, aborting...")
            return None

    elif args.resume:
        tarfile = "runs/%s-net/checkpoint.pth.tar" % args.resume
        if os.path.isfile(tarfile):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(tarfile)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(
                f"=> loaded checkpoint '{tarfile}' (epoch {checkpoint['epoch']})"
            )
        else:
            print(f"=> no checkpoint found at {tarfile}, aborting...")
            return None

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer_name = args.optimizer.lower()
    if optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    elif optimizer_name == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    else:
        # Evaluation never touches the optimizer, so only training rejects it.
        optimizer = None

    if not args.eval and optimizer is None:
        print("optimizer %s is not supported" % args.optimizer)
        return None

    if pruner_state is not None:
        # Re-apply the saved mask to this model so pruned weights stay zero.
        cutoff_retrain = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        params_retrain = get_params_for_pruning(args, model)
        pruner_retrain = prunhild.pruner.CutoffPruner(params_retrain,
                                                      cutoff_retrain)
        pruner_retrain.load_state_dict(pruner_state)
        pruner_retrain.prune(update_state=False)
        pruned_weights_count = count_pruned_weights(params_retrain,
                                                    args.cutoff)
        params_left = param_num - pruned_weights_count
        print("Pruned %d weights, New model size:  %d/%d (%d%%)" %
              (pruned_weights_count, params_left, param_num,
               int(100 * params_left / param_num)))
    else:
        pruner_retrain = None

    # Created only once setup has succeeded so that none of the early-abort
    # paths above leaves an open event-file writer behind.
    writer = SummaryWriter(log_dir="runs/%s" % args.name, comment=str(args))
    try:
        if args.eval:
            best_prec1 = validate(args, val_loader, model, criterion, 0, None)
        else:
            if args.varnet:
                save_checkpoint(
                    args,
                    {
                        "epoch": 0,
                        "state_dict": model.state_dict(),
                        "best_prec1": 0.0,
                    },
                    True,
                )
                best_prec1 = 0.0

            turns_above_50 = 0

            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(args, optimizer, epoch + 1)
                train(args, train_loader, model, criterion, optimizer, epoch,
                      pruner_retrain, writer)

                prec1 = validate(args, val_loader, model, criterion, epoch,
                                 writer)
                correlation.measure_correlation(model, epoch, writer=writer)

                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)

                if args.savenet:
                    save_checkpoint(
                        args,
                        {
                            "epoch": epoch + 1,
                            "state_dict": model.state_dict(),
                            "best_prec1": best_prec1,
                        },
                        is_best,
                    )
                # Early stop: more than 3 epochs above 50% precision ends the
                # run and reports the epoch instead of the best precision.
                if args.symmetry_break and prec1 > 50.0:
                    turns_above_50 += 1
                    if turns_above_50 > 3:
                        return epoch
    finally:
        writer.close()

    print("Best accuracy: ", best_prec1)
    return best_prec1
def main(args):
    """Prepare output folders and run the CIFAR-10 task chosen by ``args.todo``.

    Creates ``<log_root>/<dataset>_<affix>`` and ``<model_root>/<dataset>_<affix>``,
    records them on ``args`` as ``log_folder`` / ``model_folder``, logs the
    arguments and the model size, then dispatches on ``args.todo``:

    * ``'train'``   -- train a WideResNet-34 (widen factor 1) via
      ``Trainer.train``, evaluating on the CIFAR-10 test split;
    * ``'test'``    -- currently does nothing;
    * ``'cw_test'`` -- load ``args.cw_attack_modelpath`` into a WideResNet-34
      (widen factor 2) and run ``cw_attack_test`` on the test split.

    Raises:
        NotImplementedError: for any other ``args.todo`` value.
    """
    save_folder = '%s_%s' % (args.dataset, args.affix)

    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)
    for folder in (log_folder, model_folder):
        makedirs(folder)

    args.log_folder = log_folder
    args.model_folder = model_folder

    logger = create_logger(log_folder, args.todo, 'info')
    print_args(args, logger)

    # WideResNet-34 used for training.
    model = WideResNet(depth=34, num_classes=10, widen_factor=1, dropRate=0.0)
    flop, param = get_model_infos(model, (1, 3, 32, 32))
    logger.info(f'Model Info: FLOP = {flop:.2f} M, Params = {param:.2f} MB')

    # Identity normalization: the admissible input box for the attack is the
    # image of [0, 1] under (x - mean) / std, taken across all channels.
    mean, std = [0], [1]
    channel_bounds = [((0 - m) / s, (1 - m) / s) for m, s in zip(mean, std)]
    inputs_box = (min(lo for lo, _ in channel_bounds),
                  max(hi for _, hi in channel_bounds))

    attack = carlini_wagner_L2.L2Adversary(
        targeted=False,
        confidence=0.0,
        search_steps=10,
        box=inputs_box,
        optimizer_lr=5e-4,
    )

    if torch.cuda.is_available():
        model.cuda()

    trainer = Trainer(args, logger, attack)

    def _cifar10_test_loader():
        # CIFAR-10 test split as ToTensor() images, iterated in a fixed order.
        test_set = tv.datasets.CIFAR10(args.data_root,
                                       train=False,
                                       transform=tv.transforms.ToTensor(),
                                       download=True)
        return DataLoader(test_set,
                          batch_size=args.batch_size,
                          shuffle=False,
                          num_workers=4)

    todo = args.todo
    if todo == 'train':
        def _zero_pad4(img):
            # Zero-pad 4 pixels on every side ahead of the random 32x32 crop.
            return F.pad(img.unsqueeze(0), (4, 4, 4, 4),
                         mode='constant', value=0).squeeze()

        transform_train = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Lambda(_zero_pad4),
            tv.transforms.ToPILImage(),
            tv.transforms.RandomCrop(32),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
        ])
        train_set = tv.datasets.CIFAR10(args.data_root,
                                        train=True,
                                        transform=transform_train,
                                        download=True)
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)

        # Evaluation set used during training.
        test_loader = _cifar10_test_loader()
        trainer.train(model, train_loader, test_loader, args.adv_train)
    elif todo == 'test':
        pass
    elif todo == 'cw_test':
        # NOTE(review): widen_factor=2 here differs from the training model
        # above (widen_factor=1); confirm it matches the attacked checkpoint.
        model = WideResNet(depth=34,
                           num_classes=10,
                           widen_factor=2,
                           dropRate=0.0)
        print(model)
        # map_location keeps every storage where it was deserialized.
        weights = torch.load(args.cw_attack_modelpath,
                             map_location=lambda storage, loc: storage)
        model.load_state_dict(weights)
        model.cuda()
        test_loader = _cifar10_test_loader()
        cw_attack_test(model, test_loader)
    else:
        raise NotImplementedError
Esempio n. 3
0
def getPruneMask(args):
    """Compute a pruning mask from a previously trained network's checkpoint.

    Loads ``runs/<args.prune>-net/checkpoint.pth.tar`` into a freshly built
    WideResNet, prunes the parameters returned by ``get_params_for_pruning``
    with a ``LocalRatioCutoff(args.cutoff)`` pruner, and returns the pruner's
    state dict. If ``args.randomize_mask`` is set, the mask is passed through
    ``randomize_mask`` first.

    Args:
        args: Parsed command-line namespace (``prune``, ``prune_classes``,
            ``classes``, ``layers``, ``widen_factor``, ``device``, ``cutoff``,
            ``randomize_mask`` and the WideResNet options are read).

    Returns:
        The pruner state dict (the mask), or ``None`` if the checkpoint file
        does not exist.
    """
    baseTar = "runs/%s-net/checkpoint.pth.tar" % args.prune
    if not os.path.isfile(baseTar):
        print(f"=> no checkpoint found at {baseTar}")
        return None

    # A digit-free prune_classes yields 0: fall back to the current task's classes.
    classes = onlydigits(args.prune_classes)
    if classes == 0:
        classes = args.classes

    fullModel = WideResNet(
        args.layers,
        classes,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
        varnet=args.varnet,
        noise=args.noise,
        lrelu=args.lrelu,
        sigmaW=args.sigmaW,
    )

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        # NOTE(review): assumes args.device looks like "0-3" (single-digit
        # first and last device ids at positions 0 and 2) — confirm.
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = ["cuda:%d" % i for i in range(start, end)]
        fullModel = torch.nn.DataParallel(fullModel, device_ids=dev_list)

    fullModel = fullModel.cuda()

    print(f"=> loading checkpoint {baseTar}")

    checkpoint = torch.load(baseTar)
    fullModel.load_state_dict(checkpoint["state_dict"])

    # --------------------------- #
    # --- Pruning Setup Start --- #

    cutoff = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
    # don't prune the final bias weights (selection is done by
    # get_params_for_pruning)
    params = get_params_for_pruning(args, fullModel)

    print(params)

    pruner = prunhild.pruner.CutoffPruner(params,
                                          cutoff,
                                          prune_online=True)
    pruner.prune()

    print(f"=> loaded checkpoint '{baseTar}' (epoch {checkpoint['epoch']})")

    mask = pruner.state_dict()

    # Drop every reference to the full model and its checkpoint before
    # emptying the caches: empty_cache() only releases unoccupied cached
    # blocks, so memory still held by these live objects would not be freed.
    del pruner, params, cutoff, fullModel, checkpoint

    if multi_gpu:
        for i in range(start, end):
            # The device context restores the previously current device on
            # exit, so this cleanup does not change the caller's device.
            with torch.cuda.device(i):
                torch.cuda.empty_cache()

    if args.randomize_mask:
        mask = randomize_mask(mask, args.cutoff)

    return mask
Esempio n. 4
0
def main():
    """Train a WideResNet on CIFAR-10 or CIFAR-100 from command-line options.

    Parses the module-level ``parser`` into the module-level ``args`` and
    tracks the best validation precision in the module-level ``best_prec1``.
    A checkpoint is saved after every epoch.

    NOTE(review): when no checkpoint is resumed, ``best_prec1`` must already
    be initialized at module level elsewhere (not visible here).

    Raises:
        ValueError: if ``--dataset`` is neither ``cifar10`` nor ``cifar100``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Validate before any side effects; `assert` would vanish under `python -O`.
    if args.dataset not in ("cifar10", "cifar100"):
        raise ValueError(
            f"dataset must be 'cifar10' or 'cifar100', got {args.dataset!r}"
        )

    if args.tensorboard:
        configure(f"runs/{args.name}")

    # CIFAR per-channel statistics, given on the 0-255 scale.
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.augment:
        transform_train = transforms.Compose(
            [
                transforms.ToTensor(),
                # Reflect-pad 4 pixels on every side before the random crop.
                transforms.Lambda(
                    lambda x: F.pad(
                        x.unsqueeze(0), (4, 4, 4, 4), mode="reflect"
                    ).squeeze()
                ),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length)
        )

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    dataset_cls = datasets.__dict__[args.dataset.upper()]

    train_loader = torch.utils.data.DataLoader(
        dataset_cls(
            "../data", train=True, download=True, transform=transform_train
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset_cls("../data", train=False, transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    num_classes = 10 if args.dataset == "cifar10" else 100
    model = WideResNet(
        args.layers,
        num_classes,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
    )

    param_num = sum(p.data.nelement() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            # Deliberately non-fatal: training starts from scratch.
            print(f"=> no checkpoint found at {args.resume}")

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        nesterov=args.nesterov,
        weight_decay=args.weight_decay,
    )

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, epoch)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
        )

    print("Best accuracy: ", best_prec1)