Ejemplo n.º 1
0
def main():
    """Restore the classifier and both generators from a checkpoint and evaluate them.

    Parses the command-line options, fills in the dataset-dependent fields on
    ``opt`` (``num_classes``, ``input_height``, ``input_width``,
    ``input_channel``), loads ``netC``, ``netG`` and ``netM`` from
    ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``
    and calls ``eval`` with two test dataloaders.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    opt = get_arguments().parse_args()
    dataset = opt.dataset

    # (num_classes, input_height, input_width, input_channel) per dataset.
    if dataset == "mnist":
        settings = (10, 28, 28, 1)
    elif dataset == "cifar10":
        settings = (10, 32, 32, 3)
    elif dataset == "gtsrb":
        settings = (43, 32, 32, 3)
    else:
        raise Exception("Invalid Dataset")
    opt.num_classes, opt.input_height, opt.input_width, opt.input_channel = settings

    # Classifier architecture; the dataset was validated above, so the final
    # branch can only be "cifar10".
    if dataset == "mnist":
        netC = NetC_MNIST()
    elif dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43)
    else:
        netC = PreActResNet18()
    netC = netC.to(opt.device)

    def _restore(net, weights):
        # Load weights, move to the target device and freeze in eval mode.
        net.load_state_dict(weights)
        net.to(opt.device)
        net.eval()
        net.requires_grad_(False)
        return net

    ckpt_name = "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    path_model = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode, ckpt_name)
    state_dict = torch.load(path_model)

    print("load C")
    _restore(netC, state_dict["netC"])
    print("load G")
    netG = _restore(Generator(opt), state_dict["netG"])
    print("load M")
    netM = _restore(Generator(opt, out_channels=1), state_dict["netM"])

    # Two independent loaders over the test split are handed to eval.
    test_dl = get_dataloader(opt, train=False)
    test_dl2 = get_dataloader(opt, train=False)
    eval(netC, netG, netM, test_dl, test_dl2, opt)
def get_model(opt):
    """Build the classifier for ``opt.dataset`` with its SGD optimizer and LR scheduler.

    Args:
        opt: options object. Reads ``dataset``, ``device``, ``lr_C``,
            ``schedulerC_milestones`` and ``schedulerC_lambda``; ``num_classes``
            is read only for "cifar10" and "gtsrb".

    Returns:
        Tuple ``(netC, optimizerC, schedulerC)``.

    Raises:
        ValueError: if ``opt.dataset`` is not one of "cifar10", "gtsrb",
            "celeba" or "mnist". (Previously this surfaced as an
            ``AttributeError`` on ``None`` when building the optimizer.)
    """
    if opt.dataset in ("cifar10", "gtsrb"):
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == "celeba":
        netC = ResNet18().to(opt.device)
    elif opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        # Fail early with a clear message instead of crashing on netC=None.
        raise ValueError("Invalid dataset: {!r}".format(opt.dataset))

    # Optimizer
    optimizerC = torch.optim.SGD(netC.parameters(),
                                 opt.lr_C,
                                 momentum=0.9,
                                 weight_decay=5e-4)

    # Scheduler
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)

    return netC, optimizerC, schedulerC
 def _get_classifier(self, opt):
     if opt.dataset == "mnist":
         classifier = NetC_MNIST()
     elif opt.dataset == "cifar10":
         classifier = PreActResNet18()
     elif opt.dataset == "gtsrb":
         classifier = PreActResNet18(num_classes=43)
     else:
         raise Exception("Invalid Dataset")
     # Load pretrained classifier
     ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
     if not os.path.exists(ckpt_folder):
         os.makedirs(ckpt_folder)
     ckpt_path = os.path.join(ckpt_folder, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
     state_dict = torch.load(ckpt_path)
     classifier.load_state_dict(state_dict["netC"])
     for param in classifier.parameters():
         param.requires_grad = False
     classifier.eval()
     return classifier.to(opt.device)
Ejemplo n.º 4
0
def main():
    """Prune ``netC.layer3`` one channel at a time and log accuracy after each step.

    Steps:
      1. Parse options and load the MNIST classifier plus ``identity_grid`` and
         ``noise_grid`` from ``<checkpoints>/<dataset>/<dataset>_<attack_mode>_morph.pth.tar``.
      2. Run the whole test loader through ``netC`` while a forward hook
         records the outputs of ``netC.layer3``.
      3. Average those outputs over dimensions 0, 2 and 3 and sort the
         remaining dimension (assumed to be the channel axis) in ascending
         order of mean activation.
      4. For ``index`` = 0 .. C-1, build a copy of ``netC`` with the ``index``
         least-activated channels removed (no fine-tuning), evaluate it, and
         write ``"<index> <clean> <bd>"`` to ``mnist_<attack_mode>_results.txt``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist".
    """
    # Prepare arguments
    opt = get_arguments().parse_args()

    if opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
        netC = NetC_MNIST().to(opt.device)
    else:
        raise Exception("Invalid Dataset")

    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    state_dict = torch.load(opt.ckpt_path)
    print("load C")
    netC.load_state_dict(state_dict["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)
    print("load grid")
    identity_grid = state_dict["identity_grid"].to(opt.device)
    noise_grid = state_dict["noise_grid"].to(opt.device)
    print(state_dict["best_clean_acc"], state_dict["best_bd_acc"])

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # List the names of the classifier's direct sub-modules.
    for name in netC._modules:
        print(name)

    # Forward hook for getting layer's output
    container = []

    def forward_hook(module, input, output):
        container.append(output)

    hook = netC.layer3.register_forward_hook(forward_hook)

    # Forwarding all the validation set
    print("Forwarding all the validation dataset:")
    for batch_idx, (inputs, _) in enumerate(test_dl):
        inputs = inputs.to(opt.device)
        netC(inputs)
        progress_bar(batch_idx, len(test_dl))

    # Rank channels by mean activation (ascending); the mask starts with all
    # channels kept.
    container = torch.cat(container, dim=0)
    activation = torch.mean(container, dim=[0, 2, 3])
    seq_sort = torch.argsort(activation)
    pruning_mask = torch.ones(seq_sort.shape[0], dtype=bool)
    hook.remove()
    num_channels = pruning_mask.shape[0]

    # Pruning times - no-tuning after pruning a channel!!!
    with open("mnist_{}_results.txt".format(opt.attack_mode), "w") as outs:
        for index in range(num_channels):
            net_pruned = copy.deepcopy(netC)
            num_pruned = index
            if index:
                # Cumulatively drop the next least-activated channel.
                channel = seq_sort[index - 1]
                pruning_mask[channel] = False
            print("Pruned {} filters".format(num_pruned))

            num_kept = num_channels - num_pruned
            net_pruned.layer3.conv1 = nn.Conv2d(num_channels,
                                                num_kept, (3, 3),
                                                stride=2,
                                                padding=1,
                                                bias=False)
            net_pruned.linear6 = nn.Linear(num_kept * 16, 512)

            # Re-assigning weight to the pruned net
            for name, module in net_pruned._modules.items():
                if "layer3" in name:
                    module.conv1.weight.data = netC.layer3.conv1.weight.data[
                        pruning_mask]
                    module.ind = pruning_mask
                elif "linear6" == name:
                    # View the weights as (out, channel, 16) so the columns of
                    # pruned channels can be dropped; the channel count comes
                    # from the mask instead of a hard-coded 64.
                    module.weight.data = netC.linear6.weight.data.reshape(
                        -1, num_channels,
                        16)[:, pruning_mask].reshape(512, -1)
                    module.bias.data = netC.linear6.bias.data
                else:
                    continue
            net_pruned.to(opt.device)
            clean, bd = eval(net_pruned, identity_grid, noise_grid, test_dl,
                             opt)
            outs.write("%d %0.4f %0.4f\n" % (index, clean, bd))
Ejemplo n.º 5
0
def main():
    """Evaluate the restored networks on the plain test set and on preprocessed variants.

    After parsing options and restoring ``netC``, ``netG`` and ``netM`` from
    ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``,
    ``eval`` is run on:
      * the default test dataloader ("Original"),
      * dataloaders built with ``k`` = 3 and 5 ("Smoothing"),
      * dataloaders built with ``c`` = 1, 2 and 3 ("Color-depth shrinking").
    What ``k`` and ``c`` do is decided by ``get_dataloader``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    opt = get_arguments().parse_args()
    dataset = opt.dataset

    # (num_classes, input_height, input_width, input_channel) per dataset.
    if dataset == 'mnist':
        settings = (10, 28, 28, 1)
    elif dataset == 'cifar10':
        settings = (10, 32, 32, 3)
    elif dataset == 'gtsrb':
        settings = (43, 32, 32, 3)
    else:
        raise Exception("Invalid Dataset")
    opt.num_classes, opt.input_height, opt.input_width, opt.input_channel = settings

    # Classifier architecture; the dataset was validated above, so the final
    # branch can only be 'cifar10'.
    if dataset == 'mnist':
        netC = NetC_MNIST()
    elif dataset == 'gtsrb':
        netC = PreActResNet18(num_classes=43)
    else:
        netC = PreActResNet18()
    netC = netC.to(opt.device)

    def _restore(net, weights):
        # Load weights, move to the target device and freeze in eval mode.
        net.load_state_dict(weights)
        net.to(opt.device)
        net.eval()
        net.requires_grad_(False)
        return net

    ckpt_name = '{}_{}_ckpt.pth.tar'.format(opt.attack_mode, opt.dataset)
    path_model = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode,
                              ckpt_name)
    state_dict = torch.load(path_model)

    print('load C')
    _restore(netC, state_dict['netC'])
    print('load G')
    netG = _restore(Generator(opt), state_dict['netG'])
    netM = _restore(Generator(opt, out_channels=1), state_dict['netM'])

    # Baseline on the untouched test loader.
    test_dl = get_dataloader(opt, train=False)
    print('Original')
    eval(netC, netG, netM, test_dl, opt)

    print('Smoothing')
    for k in (3, 5):
        print('k = ', k)
        smoothed_dl = get_dataloader(opt, train=False, k=k)
        eval(netC, netG, netM, smoothed_dl, opt)

    print('Color-depth shrinking')
    for c in range(1, 4):
        print('c = ', c)
        shrunk_dl = get_dataloader(opt, train=False, c=c)
        eval(netC, netG, netM, shrunk_dl, opt)
def main():
    """Prune ``netC.conv5`` one channel at a time and log accuracy after each step.

    Steps:
      1. Parse options and restore ``netC``, ``netG`` and ``netM`` from
         ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``.
      2. Run the whole test loader through ``netC`` while a forward hook
         records the outputs of ``netC.relu6``.
      3. Average those outputs over dimensions 0, 2 and 3 and sort the
         remaining dimension (assumed to be the channel axis) in ascending
         order of mean activation.
      4. For ``index`` = 0 .. C-1, build a copy of ``netC`` with the ``index``
         least-activated channels removed (no fine-tuning), evaluate it, and
         write ``"<index> <clean> <bd>"`` to ``opt.outfile``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist".
    """
    # Prepare arguments
    opt = get_arguments().parse_args()
    if opt.dataset == "mnist":
        opt.num_classes = 10
    else:
        raise Exception("Invalid Dataset")
    if opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
    else:
        raise Exception("Invalid Dataset")

    # Load models
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    path_model = os.path.join(
        opt.checkpoints, opt.dataset, opt.attack_mode, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    )
    state_dict = torch.load(path_model)
    netC.load_state_dict(state_dict["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)
    netG = Generator(opt)
    netG.load_state_dict(state_dict["netG"])
    netG.to(opt.device)
    netG.eval()
    netG.requires_grad_(False)

    netM = Generator(opt, out_channels=1)
    netM.load_state_dict(state_dict["netM"])
    netM.to(opt.device)
    netM.eval()
    netM.requires_grad_(False)

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # Forward hook for getting layer's output
    container = []

    def forward_hook(module, input, output):
        container.append(output)

    hook = netC.relu6.register_forward_hook(forward_hook)

    # Forwarding all the validation set
    print("Forwarding all the validation dataset:")
    for batch_idx, (inputs, _) in enumerate(test_dl):
        inputs = inputs.to(opt.device)
        netC(inputs)
        progress_bar(batch_idx, len(test_dl))

    # Rank channels by mean activation (ascending); the mask starts with all
    # channels kept.
    container = torch.cat(container, dim=0)
    activation = torch.mean(container, dim=[0, 2, 3])
    seq_sort = torch.argsort(activation)
    pruning_mask = torch.ones(seq_sort.shape[0], dtype=bool)
    hook.remove()

    # Pruning times - no-tuning after pruning a channel!!!
    with open(opt.outfile, "w") as outs:
        for index in range(pruning_mask.shape[0]):
            net_pruned = copy.deepcopy(netC)
            num_pruned = index
            if index:
                # Step `index` removes the index-th least-activated channel,
                # i.e. seq_sort[index - 1]. Using seq_sort[index] would skip
                # the least-activated channel seq_sort[0] entirely.
                channel = seq_sort[index - 1]
                pruning_mask[channel] = False
            print("Pruned {} filters".format(num_pruned))

            net_pruned.conv5 = nn.Conv2d(64, 64 - num_pruned, (5, 5), 1, 0)
            net_pruned.linear6 = nn.Linear(16 * (64 - num_pruned), 512)

            # Re-assigning weight to the pruned net
            for name, module in net_pruned._modules.items():
                if "conv5" in name:
                    module.weight.data = netC.conv5.weight.data[pruning_mask]
                    module.bias.data = netC.conv5.bias.data[pruning_mask]
                elif "linear6" in name:
                    # View the weights as (out, 64, 16) so the columns of
                    # pruned channels can be dropped.
                    module.weight.data = netC.linear6.weight.data.reshape(-1, 64, 16)[:, pruning_mask].reshape(512, -1)
                    module.bias.data = netC.linear6.bias.data
                else:
                    continue
            clean, bd = eval(net_pruned, netG, netM, test_dl, opt)
            outs.write("%d %0.4f %0.4f\n" % (index, clean, bd))
def train(opt):
    """Train the classifier ``netC`` together with the generators ``netG`` and ``netM``.

    Builds the three networks with their optimizers and MultiStepLR
    schedulers, then resumes from
    ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``
    when that file exists (``opt`` is then replaced by the stored one).
    Otherwise it starts from epoch 1 and wipes the TensorBoard log directory.
    When starting at epoch 1, ``netM`` is first trained for 25 rounds of
    ``train_mask_step`` / ``eval_mask``. ``netM`` is then frozen and the main
    ``train_step`` / ``eval`` loop runs for up to ``opt.n_iters`` rounds,
    stopping early once ``epoch`` exceeds ``opt.n_iters``.

    Args:
        opt: options object (dataset, device, learning rates, scheduler
            settings, checkpoint location, ...).

    Raises:
        Exception: if ``opt.dataset`` is not "cifar10", "gtsrb" or "mnist".
    """
    # Networks. Construction order netC -> netG -> netM is kept so random
    # initialisation draws happen in the same sequence.
    if opt.dataset == "mnist":
        netC = NetC_MNIST()
    elif opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43)
    elif opt.dataset == "cifar10":
        netC = PreActResNet18()
    else:
        raise Exception("Invalid dataset")
    netC = netC.to(opt.device)
    netG = Generator(opt).to(opt.device)
    netM = Generator(opt, out_channels=1).to(opt.device)

    # Optimizers and their step schedulers.
    adam_betas = (0.5, 0.9)
    multi_step = torch.optim.lr_scheduler.MultiStepLR
    optimizerC = torch.optim.SGD(
        netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)
    optimizerG = torch.optim.Adam(netG.parameters(), opt.lr_G, betas=adam_betas)
    optimizerM = torch.optim.Adam(netM.parameters(), opt.lr_M, betas=adam_betas)
    schedulerC = multi_step(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    schedulerG = multi_step(
        optimizerG, opt.schedulerG_milestones, opt.schedulerG_lambda)
    schedulerM = multi_step(
        optimizerM, opt.schedulerM_milestones, opt.schedulerM_lambda)

    # Checkpoint folder and the TensorBoard directory nested inside it.
    ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    log_dir = os.path.join(ckpt_folder, "log_dir")
    for folder in (ckpt_folder, log_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)
    tf_writer = SummaryWriter(log_dir=log_dir)

    # Resume if a checkpoint is present, otherwise start fresh.
    ckpt_path = os.path.join(
        ckpt_folder, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    if not os.path.exists(ckpt_path):
        best_acc_clean = 0.0
        best_acc_bd = 0.0
        best_acc_cross = 0.0
        epoch = 1

        # Reset tensorboard
        shutil.rmtree(log_dir)
        os.makedirs(log_dir)
        print("Training from scratch")
    else:
        state_dict = torch.load(ckpt_path)
        for key, net in (("netC", netC), ("netG", netG), ("netM", netM)):
            net.load_state_dict(state_dict[key])
        epoch = state_dict["epoch"] + 1
        for key, stateful in (
            ("optimizerC", optimizerC),
            ("optimizerG", optimizerG),
            ("schedulerC", schedulerC),
            ("schedulerG", schedulerG),
        ):
            stateful.load_state_dict(state_dict[key])
        best_acc_clean = state_dict["best_acc_clean"]
        best_acc_bd = state_dict["best_acc_bd"]
        best_acc_cross = state_dict["best_acc_cross"]
        # From here on the options stored in the checkpoint are used.
        opt = state_dict["opt"]
        print("Continue training")

    # Two loaders per split.
    train_dl1 = get_dataloader(opt, train=True)
    train_dl2 = get_dataloader(opt, train=True)
    test_dl1 = get_dataloader(opt, train=False)
    test_dl2 = get_dataloader(opt, train=False)

    # Mask pre-training: only when starting at epoch 1.
    if epoch == 1:
        netM.train()
        for _ in range(25):
            print(
                "Epoch {} - {} - {} | mask_density: {} - lambda_div: {}  - lambda_norm: {}:"
                .format(epoch, opt.dataset, opt.attack_mode, opt.mask_density,
                        opt.lambda_div, opt.lambda_norm))
            train_mask_step(netM, optimizerM, schedulerM, train_dl1, train_dl2,
                            epoch, opt, tf_writer)
            # eval_mask returns the epoch counter, which is then advanced.
            epoch = eval_mask(netM, optimizerM, schedulerM, test_dl1, test_dl2,
                              epoch, opt)
            epoch += 1
    netM.eval()
    netM.requires_grad_(False)

    # Main loop over netC / netG with the mask network frozen.
    for _ in range(opt.n_iters):
        print("Epoch {} - {} - {} | mask_density: {} - lambda_div: {}:".format(
            epoch, opt.dataset, opt.attack_mode, opt.mask_density,
            opt.lambda_div))
        train_step(netC, netG, netM, optimizerC, optimizerG, schedulerC,
                   schedulerG, train_dl1, train_dl2, epoch, opt, tf_writer)

        best_acc_clean, best_acc_bd, best_acc_cross, epoch = eval(
            netC, netG, netM, optimizerC, optimizerG, schedulerC, schedulerG,
            test_dl1, test_dl2, epoch, best_acc_clean, best_acc_bd,
            best_acc_cross, opt)
        epoch += 1
        if epoch > opt.n_iters:
            break
Ejemplo n.º 8
0
def strip(opt, mode="clean"):
    """Compute STRIP entropies for clean and, optionally, backdoored test inputs.

    The classifier is loaded from
    ``<checkpoints>/<dataset>/<dataset>_<attack_mode>_morph.pth.tar``. The
    checkpoint name is derived from ``opt.attack_mode``, while ``mode``
    selects what is tested; the two are kept separate so that the caller's
    ``mode`` is honoured.

    Args:
        opt: options object. Reads ``dataset``, ``device``, ``checkpoints``,
            ``attack_mode`` and ``n_test`` (``num_classes`` only for
            "cifar10"/"gtsrb"). Writes ``ckpt_folder``, ``ckpt_path``,
            ``log_dir`` and sets ``bs`` to ``n_test``.
        mode: "attack" scores ``opt.n_test`` backdoored inputs and then
            ``opt.n_test`` clean inputs; any other value scores clean inputs
            only. ``identity_grid`` and ``noise_grid`` are read from the
            checkpoint whenever ``mode`` is not "clean".

    Returns:
        Tuple ``(list_entropy_trojan, list_entropy_benign)``; the first list
        is empty unless ``mode == "attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10", "gtsrb" or
            "celeba".
    """

    # Prepare pretrained classifier
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    elif opt.dataset == "cifar10" or opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == "celeba":
        netC = ResNet18().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    # Load pretrained model. A separate local is used for the attack mode so
    # the `mode` argument is not overwritten.
    attack_mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, attack_mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    state_dict = torch.load(opt.ckpt_path)
    netC.load_state_dict(state_dict["netC"])
    if mode != "clean":
        # Keep the grids on the same device as the inputs they are applied to.
        identity_grid = state_dict["identity_grid"].to(opt.device)
        noise_grid = state_dict["noise_grid"].to(opt.device)
    netC.requires_grad_(False)
    netC.eval()
    netC.to(opt.device)

    # Prepare test set; the batch size is set to n_test so that a single
    # batch provides all inputs to perturb.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)
    denormalizer = Denormalizer(opt)

    # STRIP detector
    strip_detector = STRIP(opt)

    # Entropy list
    list_entropy_trojan = []
    list_entropy_benign = []

    if mode == "attack":
        # Testing with perturbed data
        print("Testing with bd data !!!!")
        inputs, _ = next(iter(test_dataloader))
        inputs = inputs.to(opt.device)
        bd_inputs = create_backdoor(inputs, identity_grid, noise_grid, opt)

        # Convert to uint8 arrays with the channel axis moved last.
        bd_inputs = denormalizer(bd_inputs) * 255.0
        bd_inputs = bd_inputs.detach().cpu().numpy()
        bd_inputs = np.clip(bd_inputs, 0, 255).astype(np.uint8).transpose(
            (0, 2, 3, 1))
        for index in range(opt.n_test):
            background = bd_inputs[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_trojan.append(entropy)
            progress_bar(index, opt.n_test)

        # Testing with clean data
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
    else:
        # Testing with clean data
        print("Testing with clean data !!!!")
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
            progress_bar(index, opt.n_test)

    return list_entropy_trojan, list_entropy_benign
Ejemplo n.º 9
0
def strip(opt, mode="clean"):
    """Compute STRIP entropies for clean and, optionally, backdoored test inputs.

    The networks are restored from
    ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``.

    Args:
        opt: options object. Reads ``dataset``, ``device``, ``checkpoints``,
            ``attack_mode`` and ``n_test``. Writes ``input_height``,
            ``input_width``, ``input_channel`` and sets ``bs`` to ``n_test``.
        mode: "attack" scores ``opt.n_test`` backdoored inputs and then
            ``opt.n_test`` clean inputs; any other value scores clean inputs
            only. ``netG`` and ``netM`` are built and loaded whenever ``mode``
            is not "clean".

    Returns:
        Tuple ``(list_entropy_trojan, list_entropy_benign)``; the first list
        is empty unless ``mode == "attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    dataset = opt.dataset
    if dataset == "mnist":
        height, width, channels = 28, 28, 1
    elif dataset in ("cifar10", "gtsrb"):
        height, width, channels = 32, 32, 3
    else:
        raise Exception("Invalid dataset")
    opt.input_height = height
    opt.input_width = width
    opt.input_channel = channels

    # Classifier skeleton; weights are loaded further down. The dataset was
    # validated above, so the final branch can only be "gtsrb".
    if dataset == "mnist":
        netC = NetC_MNIST()
    elif dataset == "cifar10":
        netC = PreActResNet18()
    else:
        netC = PreActResNet18(num_classes=43)
    netC.requires_grad_(False)
    netC.to(opt.device)
    netC.eval()

    needs_generators = mode != "clean"
    if needs_generators:
        netG = Generator(opt)
        netG.requires_grad_(False)
        netG.to(opt.device)
        netG.eval()

    # Restore weights from the attack checkpoint.
    ckpt_name = "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    ckpt_path = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode,
                             ckpt_name)
    state_dict = torch.load(ckpt_path)
    netC.load_state_dict(state_dict["netC"])
    if needs_generators:
        netG.load_state_dict(state_dict["netG"])
        netM = Generator(opt, out_channels=1)
        netM.load_state_dict(state_dict["netM"])
        netM.to(opt.device)
        netM.eval()
        netM.requires_grad_(False)

    # Test data; the batch size is set to n_test so that a single batch
    # provides all inputs to perturb.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)

    strip_detector = STRIP(opt)

    def _entropy_of(background):
        # Score one background image against the test set with the classifier.
        return strip_detector(background, testset, netC)

    list_entropy_trojan = []
    list_entropy_benign = []

    if mode == "attack":
        print("Testing with bd data !!!!")
        inputs, _ = next(iter(test_dataloader))
        inputs = inputs.to(opt.device)

        # Blend the generated pattern into the inputs where the mask is active.
        patterns = netG.normalize_pattern(netG(inputs))
        batch_masks = netM.threshold(netM(inputs))
        bd_inputs = inputs + (patterns - inputs) * batch_masks

        # Convert to uint8 arrays with the channel axis moved last.
        bd_inputs = netG.denormalize_pattern(bd_inputs) * 255.0
        bd_inputs = bd_inputs.detach().cpu().numpy()
        bd_inputs = np.clip(bd_inputs, 0, 255).astype(np.uint8)
        bd_inputs = bd_inputs.transpose((0, 2, 3, 1))

        for index in range(opt.n_test):
            list_entropy_trojan.append(_entropy_of(bd_inputs[index]))
            progress_bar(index, opt.n_test)

        # Clean samples for comparison (no progress bar on this pass).
        for index in range(opt.n_test):
            background, _ = testset[index]
            list_entropy_benign.append(_entropy_of(background))
    else:
        print("Testing with clean data !!!!")
        for index in range(opt.n_test):
            background, _ = testset[index]
            list_entropy_benign.append(_entropy_of(background))
            progress_bar(index, opt.n_test)

    return list_entropy_trojan, list_entropy_benign