def _get_classifier(self, opt):
    """Build the classifier for ``opt.dataset`` and load its trained weights.

    The checkpoint is read from
    ``<opt.checkpoints>/<opt.dataset>/<opt.attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``
    and its ``"netC"`` entry is loaded into the network. All parameters are
    frozen and the model is switched to eval mode.

    Args:
        self: Unused; kept for interface compatibility.
        opt: Options namespace; reads ``dataset``, ``checkpoints``,
            ``attack_mode`` and ``device``.

    Returns:
        The frozen classifier, moved to ``opt.device``.

    Raises:
        Exception: If ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
        FileNotFoundError: If the checkpoint file is missing (propagated
            from ``torch.load``).
    """
    if opt.dataset == "mnist":
        classifier = NetC_MNIST()
    elif opt.dataset == "cifar10":
        classifier = PreActResNet18()
    elif opt.dataset == "gtsrb":
        classifier = PreActResNet18(num_classes=43)
    else:
        raise Exception("Invalid Dataset")

    # Load pretrained classifier. The folder is deliberately not created here:
    # a missing checkpoint is an error either way, and creating an empty
    # directory would only leave clutter behind.
    ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    ckpt_path = os.path.join(
        ckpt_folder, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded
    # on the device this run targets (e.g. a CPU-only machine).
    state_dict = torch.load(ckpt_path, map_location=opt.device)
    classifier.load_state_dict(state_dict["netC"])
    classifier.requires_grad_(False)
    classifier.eval()
    return classifier.to(opt.device)
def get_model(opt):
    """Create the classifier together with its SGD optimizer and LR scheduler.

    Args:
        opt: Options namespace; reads ``dataset``, ``device``, ``num_classes``
            (cifar10/gtsrb only), ``lr_C``, ``schedulerC_milestones`` and
            ``schedulerC_lambda``.

    Returns:
        Tuple ``(netC, optimizerC, schedulerC)``.

    Raises:
        ValueError: If ``opt.dataset`` is not one of "cifar10", "gtsrb",
            "celeba" or "mnist".
    """
    if opt.dataset in ("cifar10", "gtsrb"):
        netC = PreActResNet18(num_classes=opt.num_classes)
    elif opt.dataset == "celeba":
        netC = ResNet18()
    elif opt.dataset == "mnist":
        netC = NetC_MNIST()
    else:
        # Fail fast with a clear message instead of letting the optimizer
        # construction below crash on ``None.parameters()``.
        raise ValueError("Invalid dataset")
    netC = netC.to(opt.device)

    # Optimizer
    optimizerC = torch.optim.SGD(netC.parameters(),
                                 opt.lr_C,
                                 momentum=0.9,
                                 weight_decay=5e-4)

    # Scheduler: multiply the LR by schedulerC_lambda at each milestone.
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)

    return netC, optimizerC, schedulerC
def train(opt):
    """Train the classifier, the pattern generator and the mask generator.

    Builds ``netC`` (classifier for ``opt.dataset``), ``netG`` (pattern
    generator) and ``netM`` (single-channel mask generator) with their
    optimizers and LR schedulers. If a checkpoint already exists under
    ``<opt.checkpoints>/<opt.dataset>/<opt.attack_mode>/``, training resumes
    from it. Otherwise the TensorBoard log directory is wiped and training
    starts at epoch 1.

    A fresh run first trains ``netM`` alone for 25 epochs
    (``train_mask_step`` / ``eval_mask``). ``netM`` is then frozen, and
    ``train_step`` / ``eval`` run once per epoch until the epoch counter
    exceeds ``opt.n_iters``. The mask epochs count towards that budget.

    Args:
        opt: Options namespace. NOTE: when resuming, it is replaced by the
            ``opt`` object stored in the checkpoint.

    Raises:
        Exception: If ``opt.dataset`` is not "cifar10", "gtsrb" or "mnist".
    """
    # Prepare model related things
    if opt.dataset == "cifar10":
        netC = PreActResNet18().to(opt.device)
    elif opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43).to(opt.device)
    elif opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    netG = Generator(opt).to(opt.device)
    optimizerC = torch.optim.SGD(netC.parameters(),
                                 opt.lr_C,
                                 momentum=0.9,
                                 weight_decay=5e-4)
    optimizerG = torch.optim.Adam(netG.parameters(),
                                  opt.lr_G,
                                  betas=(0.5, 0.9))
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizerG, opt.schedulerG_milestones, opt.schedulerG_lambda)

    netM = Generator(opt, out_channels=1).to(opt.device)
    optimizerM = torch.optim.Adam(netM.parameters(),
                                  opt.lr_M,
                                  betas=(0.5, 0.9))
    schedulerM = torch.optim.lr_scheduler.MultiStepLR(
        optimizerM, opt.schedulerM_milestones, opt.schedulerM_lambda)

    # The checkpoint and the TensorBoard logs share one folder.
    ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    ckpt_path = os.path.join(
        ckpt_folder, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    log_dir = os.path.join(ckpt_folder, "log_dir")
    os.makedirs(log_dir, exist_ok=True)

    # Continue training ?
    if os.path.exists(ckpt_path):
        # map_location lets a checkpoint saved on another device be resumed
        # on the device this run targets.
        state_dict = torch.load(ckpt_path, map_location=opt.device)
        netC.load_state_dict(state_dict["netC"])
        netG.load_state_dict(state_dict["netG"])
        netM.load_state_dict(state_dict["netM"])
        epoch = state_dict["epoch"] + 1
        optimizerC.load_state_dict(state_dict["optimizerC"])
        optimizerG.load_state_dict(state_dict["optimizerG"])
        schedulerC.load_state_dict(state_dict["schedulerC"])
        schedulerG.load_state_dict(state_dict["schedulerG"])
        best_acc_clean = state_dict["best_acc_clean"]
        best_acc_bd = state_dict["best_acc_bd"]
        best_acc_cross = state_dict["best_acc_cross"]
        # The stored options replace the ones passed in.
        opt = state_dict["opt"]
        print("Continue training")
    else:
        # Fresh run: reset the best-so-far metrics.
        best_acc_clean = 0.0
        best_acc_bd = 0.0
        best_acc_cross = 0.0
        epoch = 1

        # Reset tensorboard. This must happen BEFORE the SummaryWriter is
        # created: the writer creates its event file in log_dir on
        # construction, so wiping the directory afterwards would discard the
        # file it logs to (or fail to delete it on Windows).
        shutil.rmtree(log_dir)
        os.makedirs(log_dir)
        print("Training from scratch")

    # For tensorboard
    tf_writer = SummaryWriter(log_dir=log_dir)

    try:
        # Prepare dataset. Two loaders per split are passed to each step.
        train_dl1 = get_dataloader(opt, train=True)
        train_dl2 = get_dataloader(opt, train=True)
        test_dl1 = get_dataloader(opt, train=False)
        test_dl2 = get_dataloader(opt, train=False)

        # A fresh run first trains the mask generator on its own.
        if epoch == 1:
            netM.train()
            for _ in range(25):
                print(
                    "Epoch {} - {} - {} | mask_density: {} - lambda_div: {}  - lambda_norm: {}:"
                    .format(epoch, opt.dataset, opt.attack_mode, opt.mask_density,
                            opt.lambda_div, opt.lambda_norm))
                train_mask_step(netM, optimizerM, schedulerM, train_dl1, train_dl2,
                                epoch, opt, tf_writer)
                epoch = eval_mask(netM, optimizerM, schedulerM, test_dl1, test_dl2,
                                  epoch, opt)
                epoch += 1
        # From here on the mask generator is fixed.
        netM.eval()
        netM.requires_grad_(False)

        for _ in range(opt.n_iters):
            print("Epoch {} - {} - {} | mask_density: {} - lambda_div: {}:".format(
                epoch, opt.dataset, opt.attack_mode, opt.mask_density,
                opt.lambda_div))
            train_step(
                netC,
                netG,
                netM,
                optimizerC,
                optimizerG,
                schedulerC,
                schedulerG,
                train_dl1,
                train_dl2,
                epoch,
                opt,
                tf_writer,
            )

            # `eval` here is the project's evaluation routine (it shadows the
            # builtin); presumably defined elsewhere in this module.
            best_acc_clean, best_acc_bd, best_acc_cross, epoch = eval(
                netC,
                netG,
                netM,
                optimizerC,
                optimizerG,
                schedulerC,
                schedulerG,
                test_dl1,
                test_dl2,
                epoch,
                best_acc_clean,
                best_acc_bd,
                best_acc_cross,
                opt,
            )
            epoch += 1
            if epoch > opt.n_iters:
                break
    finally:
        # Flush pending TensorBoard events and release the event file.
        tf_writer.close()
# Example No. 4
# 0
def strip(opt, mode="clean"):
    """Run the STRIP detector on clean and, optionally, backdoored test inputs.

    Args:
        opt: Options namespace. This function mutates it: it sets
            ``input_height``, ``input_width`` and ``input_channel`` from
            ``opt.dataset``, and sets ``opt.bs = opt.n_test``. It also reads
            ``checkpoints``, ``attack_mode``, ``device`` and ``n_test``.
        mode: "clean" evaluates ``opt.n_test`` clean test images. "attack"
            additionally builds one batch of backdoored inputs with the
            checkpoint's ``netG`` / ``netM`` and evaluates those first. Any
            other value loads ``netG`` / ``netM`` but evaluates clean images only.

    Returns:
        Tuple ``(list_entropy_trojan, list_entropy_benign)`` holding the
        values returned by the STRIP detector per sample.
        ``list_entropy_trojan`` is empty unless ``mode == "attack"``.

    Raises:
        Exception: If ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    # (height, width, channels) of the inputs for each supported dataset.
    input_shapes = {
        "mnist": (28, 28, 1),
        "cifar10": (32, 32, 3),
        "gtsrb": (32, 32, 3),
    }
    if opt.dataset not in input_shapes:
        raise Exception("Invalid dataset")
    opt.input_height, opt.input_width, opt.input_channel = input_shapes[opt.dataset]

    # Prepare pretrained classifier (frozen, eval mode).
    if opt.dataset == "mnist":
        netC = NetC_MNIST()
    elif opt.dataset == "cifar10":
        netC = PreActResNet18()
    else:
        netC = PreActResNet18(num_classes=43)
    netC.requires_grad_(False)
    netC.to(opt.device)
    netC.eval()

    if mode != "clean":
        netG = Generator(opt)
        netG.requires_grad_(False)
        netG.to(opt.device)
        netG.eval()

    # Load pretrained model
    ckpt_dir = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    ckpt_path = os.path.join(
        ckpt_dir, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be
    # evaluated on the device this run targets (e.g. a CPU-only machine).
    state_dict = torch.load(ckpt_path, map_location=opt.device)
    netC.load_state_dict(state_dict["netC"])
    if mode != "clean":
        netG.load_state_dict(state_dict["netG"])
        netM = Generator(opt, out_channels=1)
        netM.load_state_dict(state_dict["netM"])
        netM.to(opt.device)
        netM.eval()
        netM.requires_grad_(False)

    # Prepare test set. A single batch of n_test images is drawn below.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)

    # STRIP detector
    strip_detector = STRIP(opt)

    # Entropy list
    list_entropy_trojan = []
    list_entropy_benign = []

    if mode == "attack":
        # Testing with perturbed data
        print("Testing with bd data !!!!")
        inputs, _ = next(iter(test_dataloader))
        inputs = inputs.to(opt.device)
        patterns = netG.normalize_pattern(netG(inputs))
        batch_masks = netM.threshold(netM(inputs))
        # Blend the generated pattern into each input where the mask is set.
        bd_inputs = inputs + (patterns - inputs) * batch_masks

        # Back to uint8 images in channels-last layout for the detector.
        bd_inputs = netG.denormalize_pattern(bd_inputs) * 255.0
        bd_inputs = bd_inputs.detach().cpu().numpy()
        bd_inputs = np.clip(bd_inputs, 0, 255).astype(np.uint8).transpose(
            (0, 2, 3, 1))
        for index in range(opt.n_test):
            entropy = strip_detector(bd_inputs[index], testset, netC)
            list_entropy_trojan.append(entropy)
            progress_bar(index, opt.n_test)

        # Testing with clean data
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
    else:
        # Testing with clean data
        print("Testing with clean data !!!!")
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
            progress_bar(index, opt.n_test)

    return list_entropy_trojan, list_entropy_benign