Example #1
0
def main():
    """Load a trained classifier and its two generators, then run ``eval`` on the test set.

    Parses CLI options with ``get_arguments()`` and fills in the dataset-dependent
    fields ``num_classes``, ``input_height``, ``input_width`` and ``input_channel``
    on ``opt``. It then restores ``netC``, ``netG`` and ``netM`` from
    ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``.
    Each network is moved to ``opt.device``, put in eval mode and frozen. Finally
    the three networks are passed, together with two test dataloaders, to ``eval``.

    Raises:
        Exception: if ``opt.dataset`` is not one of ``mnist``, ``cifar10``, ``gtsrb``.
    """
    # Prepare arguments
    opt = get_arguments().parse_args()

    # dataset -> (num_classes, input_height, input_width, input_channel)
    dataset_specs = {
        "mnist": (10, 28, 28, 1),
        "cifar10": (10, 32, 32, 3),
        "gtsrb": (43, 32, 32, 3),
    }
    if opt.dataset not in dataset_specs:
        raise Exception("Invalid Dataset")
    opt.num_classes, opt.input_height, opt.input_width, opt.input_channel = dataset_specs[opt.dataset]

    # Load models and masks
    if opt.dataset == "mnist":
        netC = NetC_MNIST()
    else:
        # cifar10 / gtsrb: take the class count from opt instead of a second hard-coded copy.
        netC = PreActResNet18(num_classes=opt.num_classes)

    path_model = os.path.join(
        opt.checkpoints, opt.dataset, opt.attack_mode, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    )
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded on another (e.g. CPU).
    state_dict = torch.load(path_model, map_location=opt.device)

    def restore(net, key):
        """Load ``state_dict[key]`` into ``net``, move it to ``opt.device`` and freeze it for inference."""
        net.load_state_dict(state_dict[key])
        net.to(opt.device)
        net.eval()
        net.requires_grad_(False)
        return net

    print("load C")
    netC = restore(netC, "netC")
    print("load G")
    netG = restore(Generator(opt), "netG")
    print("load M")
    netM = restore(Generator(opt, out_channels=1), "netM")

    # Prepare dataloader (two independent test loaders are passed to eval)
    test_dl = get_dataloader(opt, train=False)
    test_dl2 = get_dataloader(opt, train=False)
    eval(netC, netG, netM, test_dl, test_dl2, opt)
Example #2
0
def main():
    """Prune ``netC.layer3`` output channels one at a time and log clean/backdoor accuracy.

    MNIST only. Restores ``netC`` plus ``identity_grid`` / ``noise_grid`` from
    ``<checkpoints>/mnist/mnist_<attack_mode>_morph.pth.tar``. It then ranks the
    output channels of ``netC.layer3`` by their mean activation over the test set,
    in ascending order.

    For k = 0 .. C-1 it builds a copy of ``netC`` with the k least active channels
    removed, with no fine-tuning. It evaluates that copy with ``eval`` and writes a
    line ``"<k> <clean> <bd>"`` to ``mnist_<attack_mode>_results.txt``.

    Raises:
        Exception: if ``opt.dataset`` is not ``"mnist"``.
    """
    # Prepare arguments
    opt = get_arguments().parse_args()

    if opt.dataset != "mnist":
        raise Exception("Invalid Dataset")
    opt.input_height = 28
    opt.input_width = 28
    opt.input_channel = 1
    netC = NetC_MNIST().to(opt.device)

    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded on another (e.g. CPU).
    state_dict = torch.load(opt.ckpt_path, map_location=opt.device)
    print("load C")
    netC.load_state_dict(state_dict["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)
    print("load grid")
    identity_grid = state_dict["identity_grid"].to(opt.device)
    noise_grid = state_dict["noise_grid"].to(opt.device)
    print(state_dict["best_clean_acc"], state_dict["best_bd_acc"])

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # Debug output: names of the classifier's top-level submodules.
    for name in netC._modules:
        print(name)

    # Forward hook collecting layer3's output for every test batch.
    container = []

    def forward_hook(module, input, output):
        container.append(output)

    hook = netC.layer3.register_forward_hook(forward_hook)
    try:
        # Forwarding all the validation set
        print("Forwarding all the validation dataset:")
        for batch_idx, (inputs, _) in enumerate(test_dl):
            netC(inputs.to(opt.device))
            progress_bar(batch_idx, len(test_dl))
    finally:
        # Detach the hook even if forwarding fails part-way.
        hook.remove()

    # Mean activation per channel (dims 0, 2, 3 are reduced). An ascending argsort
    # puts the least active channel first, so it is pruned first.
    activation = torch.mean(torch.cat(container, dim=0), dim=[0, 2, 3])
    seq_sort = torch.argsort(activation)
    num_channels = seq_sort.shape[0]
    pruning_mask = torch.ones(num_channels, dtype=torch.bool)

    # linear6 consumes the flattened layer3 output: in_features == num_channels * spatial.
    # These were previously hard-coded as 64 / 16 / 512.
    linear_out, linear_in = netC.linear6.weight.shape
    spatial = linear_in // num_channels

    # Pruning times - no fine-tuning after pruning a channel!
    with open("mnist_{}_results.txt".format(opt.attack_mode), "w") as outs:
        for num_pruned in range(num_channels):
            net_pruned = copy.deepcopy(netC)
            if num_pruned:
                # Pruning is cumulative: mask the next least active channel.
                pruning_mask[seq_sort[num_pruned - 1]] = False
            print("Pruned {} filters".format(num_pruned))
            kept = num_channels - num_pruned

            # NOTE(review): in_channels is set to layer3's output width; the effective
            # input width comes from the copied weight tensor below — confirm.
            net_pruned.layer3.conv1 = nn.Conv2d(num_channels, kept, (3, 3), stride=2, padding=1, bias=False)
            net_pruned.linear6 = nn.Linear(kept * spatial, linear_out)

            # Re-assigning weight to the pruned net
            for name, module in net_pruned._modules.items():
                if "layer3" in name:
                    module.conv1.weight.data = netC.layer3.conv1.weight.data[pruning_mask]
                    # NOTE(review): presumably read by layer3's forward to apply the mask — confirm.
                    module.ind = pruning_mask
                elif name == "linear6":
                    # Drop the input features belonging to pruned channels.
                    module.weight.data = (
                        netC.linear6.weight.data.reshape(-1, num_channels, spatial)[:, pruning_mask]
                        .reshape(linear_out, -1)
                    )
                    module.bias.data = netC.linear6.bias.data
            net_pruned.to(opt.device)
            clean, bd = eval(net_pruned, identity_grid, noise_grid, test_dl, opt)
            outs.write("%d %0.4f %0.4f\n" % (num_pruned, clean, bd))
Example #3
0
def main():
    """Evaluate a trained classifier/generator triple under input-preprocessing variants.

    Restores ``netC``, ``netG`` and ``netM`` from
    ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``.
    Each network is moved to ``opt.device``, put in eval mode and frozen.

    It then calls ``eval`` on:

    - the plain test loader;
    - loaders built with ``k=3`` and ``k=5`` (the "Smoothing" section);
    - loaders built with ``c=1..3`` (the "Color-depth shrinking" section).

    Raises:
        Exception: if ``opt.dataset`` is not one of ``mnist``, ``cifar10``, ``gtsrb``.
    """
    # Prepare arguments
    opt = get_arguments().parse_args()

    # dataset -> (num_classes, input_height, input_width, input_channel)
    dataset_specs = {
        'mnist': (10, 28, 28, 1),
        'cifar10': (10, 32, 32, 3),
        'gtsrb': (43, 32, 32, 3),
    }
    if opt.dataset not in dataset_specs:
        raise Exception("Invalid Dataset")
    opt.num_classes, opt.input_height, opt.input_width, opt.input_channel = dataset_specs[opt.dataset]

    # Load models and masks
    if opt.dataset == 'mnist':
        netC = NetC_MNIST()
    else:
        # cifar10 / gtsrb: take the class count from opt instead of a second hard-coded copy.
        netC = PreActResNet18(num_classes=opt.num_classes)

    path_model = os.path.join(
        opt.checkpoints, opt.dataset, opt.attack_mode,
        '{}_{}_ckpt.pth.tar'.format(opt.attack_mode, opt.dataset))
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded on another (e.g. CPU).
    state_dict = torch.load(path_model, map_location=opt.device)

    def restore(net, key):
        """Load ``state_dict[key]`` into ``net``, move it to ``opt.device`` and freeze it."""
        net.load_state_dict(state_dict[key])
        net.to(opt.device)
        net.eval()
        net.requires_grad_(False)
        return net

    print('load C')
    netC = restore(netC, 'netC')
    print('load G')
    netG = restore(Generator(opt), 'netG')
    netM = restore(Generator(opt, out_channels=1), 'netM')

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    print('Original')
    eval(netC, netG, netM, test_dl, opt)

    print('Smoothing')
    for k in (3, 5):
        print('k = ', k)
        # NOTE(review): k is presumably a smoothing kernel size applied by get_dataloader — confirm.
        eval(netC, netG, netM, get_dataloader(opt, train=False, k=k), opt)

    print('Color-depth shrinking')
    for c in range(1, 4):
        print('c = ', c)
        eval(netC, netG, netM, get_dataloader(opt, train=False, c=c), opt)
def main():
    """Prune ``netC.conv5`` output channels one at a time and log clean/backdoor accuracy.

    MNIST only. Restores ``netC``, ``netG`` and ``netM`` from
    ``<checkpoints>/mnist/<attack_mode>/<attack_mode>_mnist_ckpt.pth.tar``.
    It ranks the channels of ``netC.relu6``'s output by mean activation over the
    test set, in ascending order; these are assumed to be ``conv5``'s output
    channels, since the same mask indexes ``conv5``'s weights.

    For k = 0 .. C-1 it removes the k least active channels from ``conv5`` and the
    matching ``linear6`` inputs, with no fine-tuning. It evaluates each pruned copy
    with ``eval`` and writes ``"<k> <clean> <bd>"`` lines to ``opt.outfile``.

    Raises:
        Exception: if ``opt.dataset`` is not ``"mnist"``.
    """
    # Prepare arguments
    opt = get_arguments().parse_args()
    if opt.dataset != "mnist":
        raise Exception("Invalid Dataset")
    opt.num_classes = 10
    opt.input_height = 28
    opt.input_width = 28
    opt.input_channel = 1

    # Load models
    netC = NetC_MNIST().to(opt.device)

    path_model = os.path.join(
        opt.checkpoints, opt.dataset, opt.attack_mode, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    )
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded on another (e.g. CPU).
    state_dict = torch.load(path_model, map_location=opt.device)
    netC.load_state_dict(state_dict["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)
    netG = Generator(opt)
    netG.load_state_dict(state_dict["netG"])
    netG.to(opt.device)
    netG.eval()
    netG.requires_grad_(False)

    netM = Generator(opt, out_channels=1)
    netM.load_state_dict(state_dict["netM"])
    netM.to(opt.device)
    netM.eval()
    netM.requires_grad_(False)

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # Forward hook collecting relu6's output for every test batch.
    container = []

    def forward_hook(module, input, output):
        container.append(output)

    hook = netC.relu6.register_forward_hook(forward_hook)
    try:
        # Forwarding all the validation set
        print("Forwarding all the validation dataset:")
        for batch_idx, (inputs, _) in enumerate(test_dl):
            netC(inputs.to(opt.device))
            progress_bar(batch_idx, len(test_dl))
    finally:
        # Detach the hook even if forwarding fails part-way.
        hook.remove()

    # Mean activation per channel (dims 0, 2, 3 are reduced). An ascending argsort
    # puts the least active channel first, so it is pruned first.
    activation = torch.mean(torch.cat(container, dim=0), dim=[0, 2, 3])
    seq_sort = torch.argsort(activation)
    num_channels = seq_sort.shape[0]
    pruning_mask = torch.ones(num_channels, dtype=torch.bool)

    # linear6 consumes the flattened conv5 features: in_features == num_channels * spatial.
    # These were previously hard-coded as 64 / 16 / 512.
    linear_out, linear_in = netC.linear6.weight.shape
    spatial = linear_in // num_channels

    # Pruning times - no fine-tuning after pruning a channel!
    with open(opt.outfile, "w") as outs:
        for num_pruned in range(num_channels):
            net_pruned = copy.deepcopy(netC)
            if num_pruned:
                # Pruning is cumulative: after num_pruned steps, exactly the channels
                # seq_sort[0 .. num_pruned-1] must be masked. The previous seq_sort[num_pruned]
                # never pruned the least active channel and eventually pruned the most active one.
                pruning_mask[seq_sort[num_pruned - 1]] = False
            print("Pruned {} filters".format(num_pruned))
            kept = num_channels - num_pruned

            net_pruned.conv5 = nn.Conv2d(64, kept, (5, 5), 1, 0)
            net_pruned.linear6 = nn.Linear(spatial * kept, linear_out)

            # Re-assigning weight to the pruned net
            for name, module in net_pruned._modules.items():
                if "conv5" in name:
                    module.weight.data = netC.conv5.weight.data[pruning_mask]
                    module.bias.data = netC.conv5.bias.data[pruning_mask]
                elif "linear6" in name:
                    # Drop the input features belonging to pruned channels.
                    module.weight.data = (
                        netC.linear6.weight.data.reshape(-1, num_channels, spatial)[:, pruning_mask]
                        .reshape(linear_out, -1)
                    )
                    module.bias.data = netC.linear6.bias.data
            net_pruned.to(opt.device)
            clean, bd = eval(net_pruned, netG, netM, test_dl, opt)
            outs.write("%d %0.4f %0.4f\n" % (num_pruned, clean, bd))
Example #5
0
def strip(opt, mode="clean"):
    """Collect STRIP entropies for poisoned and/or clean test images.

    Builds the classifier that matches ``opt.dataset`` and restores it from
    ``<checkpoints>/<dataset>/<dataset>_<attack_mode>_morph.pth.tar``. Test
    images are then scored with a ``STRIP(opt)`` detector.

    If the effective mode is ``"attack"``:

    - the first test batch is poisoned with ``create_backdoor``;
    - the first ``opt.n_test`` poisoned images are scored;
    - the first ``opt.n_test`` clean test images are scored afterwards.

    In any other mode only the clean images are scored.

    Side effects on ``opt``: sets ``ckpt_folder``, ``ckpt_path`` and ``log_dir``,
    and overwrites ``bs`` with ``n_test``.

    NOTE(review): ``mode`` is immediately replaced by ``opt.attack_mode``, so the
    value the caller passes has no effect. It is kept for interface
    compatibility — confirm whether this is intended.

    Returns:
        ``(list_entropy_trojan, list_entropy_benign)``. The first list stays empty
        unless the mode is ``"attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not mnist, cifar10, gtsrb or celeba.
    """
    # Architecture per dataset; built lazily so only the selected one is constructed.
    architectures = {
        "mnist": lambda: NetC_MNIST(),
        "cifar10": lambda: PreActResNet18(num_classes=opt.num_classes),
        "gtsrb": lambda: PreActResNet18(num_classes=opt.num_classes),
        "celeba": lambda: ResNet18(),
    }
    if opt.dataset not in architectures:
        raise Exception("Invalid dataset")
    netC = architectures[opt.dataset]().to(opt.device)

    # Checkpoint location (the caller-supplied mode is deliberately superseded here).
    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    checkpoint = torch.load(opt.ckpt_path)
    netC.load_state_dict(checkpoint["netC"])
    if mode != "clean":
        identity_grid = checkpoint["identity_grid"]
        noise_grid = checkpoint["noise_grid"]
    netC.requires_grad_(False)
    netC.eval()
    netC.to(opt.device)

    # Test data: one batch of n_test samples feeds the poisoning step.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)
    denormalizer = Denormalizer(opt)

    strip_detector = STRIP(opt)
    list_entropy_trojan = []
    list_entropy_benign = []

    def score_clean_images(show_progress):
        """Append the entropy of each of the first ``opt.n_test`` clean test images."""
        for idx in range(opt.n_test):
            image, _ = testset[idx]
            list_entropy_benign.append(strip_detector(image, testset, netC))
            if show_progress:
                progress_bar(idx, opt.n_test)

    if mode != "attack":
        print("Testing with clean data !!!!")
        score_clean_images(show_progress=True)
        return list_entropy_trojan, list_entropy_benign

    print("Testing with bd data !!!!")
    batch, _ = next(iter(test_dataloader))
    poisoned = create_backdoor(batch.to(opt.device), identity_grid, noise_grid, opt)

    # Back to uint8 pixel values in [0, 255], with axis 1 moved last
    # (presumably NCHW -> NHWC — confirm against Denormalizer/STRIP).
    pixels = (denormalizer(poisoned) * 255.0).detach().cpu().numpy()
    pixels = np.clip(pixels, 0, 255).astype(np.uint8).transpose((0, 2, 3, 1))
    for idx in range(opt.n_test):
        list_entropy_trojan.append(strip_detector(pixels[idx], testset, netC))
        progress_bar(idx, opt.n_test)

    score_clean_images(show_progress=False)
    return list_entropy_trojan, list_entropy_benign