def main():
    """Evaluate a trained classifier together with its generator and mask nets.

    Parses the command-line options, fills in the dataset-dependent fields on
    ``opt``, restores ``netC``, ``netG`` and ``netM`` from the checkpoint
    located under ``opt.checkpoints/opt.dataset/opt.attack_mode`` and hands
    them, with two test dataloaders, to the project's ``eval`` routine.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    opt = get_arguments().parse_args()

    # dataset -> (num_classes, input_height, input_width, input_channel)
    dataset_specs = {
        "mnist": (10, 28, 28, 1),
        "cifar10": (10, 32, 32, 3),
        "gtsrb": (43, 32, 32, 3),
    }
    if opt.dataset not in dataset_specs:
        raise Exception("Invalid Dataset")
    (
        opt.num_classes,
        opt.input_height,
        opt.input_width,
        opt.input_channel,
    ) = dataset_specs[opt.dataset]

    # Build the classifier architecture that matches the dataset.
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    elif opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43).to(opt.device)
    elif opt.dataset == "cifar10":
        netC = PreActResNet18().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    ckpt_name = "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    path_model = os.path.join(
        opt.checkpoints, opt.dataset, opt.attack_mode, ckpt_name)
    checkpoint = torch.load(path_model)

    def _restore(net, key):
        # Load weights, move to the target device and freeze for inference.
        net.load_state_dict(checkpoint[key])
        net.to(opt.device)
        net.eval()
        net.requires_grad_(False)
        return net

    print("load C")
    _restore(netC, "netC")
    print("load G")
    netG = _restore(Generator(opt), "netG")
    print("load M")
    netM = _restore(Generator(opt, out_channels=1), "netM")

    # Two independently created loaders over the test split.
    first_loader = get_dataloader(opt, train=False)
    second_loader = get_dataloader(opt, train=False)
    eval(netC, netG, netM, first_loader, second_loader, opt)
def get_model(opt):
    """Build the classifier for ``opt.dataset`` with its optimizer and scheduler.

    Args:
        opt: options object providing ``dataset``, ``device``, ``num_classes``,
            ``lr_C``, ``schedulerC_milestones`` and ``schedulerC_lambda``.

    Returns:
        A ``(netC, optimizerC, schedulerC)`` tuple: the classifier moved to
        ``opt.device``, an SGD optimizer over its parameters, and a
        ``MultiStepLR`` scheduler wrapping that optimizer.

    Raises:
        ValueError: if ``opt.dataset`` is not one of "cifar10", "gtsrb",
            "celeba" or "mnist". Previously this case fell through with
            ``netC = None`` and failed later with an opaque AttributeError.
    """
    if opt.dataset == "cifar10" or opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == "celeba":
        netC = ResNet18().to(opt.device)
    elif opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        # Fail fast with a clear message instead of dereferencing None below.
        raise ValueError("Invalid dataset")

    # Optimizer
    optimizerC = torch.optim.SGD(
        netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)

    # Scheduler
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)

    return netC, optimizerC, schedulerC
def _get_classifier(self, opt):
    """Return the frozen, pretrained classifier for ``opt.dataset``.

    The checkpoint is read from
    ``opt.checkpoints/opt.dataset/opt.attack_mode``; that folder is created
    first if it does not exist yet. The returned network is in eval mode,
    has gradients disabled and lives on ``opt.device``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    builders = {
        "mnist": NetC_MNIST,
        "cifar10": PreActResNet18,
        "gtsrb": lambda: PreActResNet18(num_classes=43),
    }
    build = builders.get(opt.dataset)
    if build is None:
        raise Exception("Invalid Dataset")
    classifier = build()

    # Load pretrained classifier
    ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    if not os.path.exists(ckpt_folder):
        os.makedirs(ckpt_folder)
    ckpt_name = "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)
    checkpoint = torch.load(os.path.join(ckpt_folder, ckpt_name))
    classifier.load_state_dict(checkpoint["netC"])

    # Inference only: no gradients, eval-mode layers.
    classifier.requires_grad_(False)
    classifier.eval()
    return classifier.to(opt.device)
def main():
    """Prune ``layer3`` channels of the MNIST classifier one at a time.

    Loads the classifier and the two grids from the ``*_morph.pth.tar``
    checkpoint, records the ``layer3`` outputs over the test set, orders the
    channels by mean activation and then, for every pruning level, rebuilds a
    copy of the network without the pruned channels (no fine-tuning) and
    evaluates it. One ``"<index> <clean> <bd>"`` line per level is written to
    ``mnist_<attack_mode>_results.txt``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist".
    """
    opt = get_arguments().parse_args()
    if opt.dataset != "mnist":
        raise Exception("Invalid Dataset")
    opt.input_height = 28
    opt.input_width = 28
    opt.input_channel = 1
    netC = NetC_MNIST().to(opt.device)

    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    checkpoint = torch.load(opt.ckpt_path)
    print("load C")
    netC.load_state_dict(checkpoint["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)
    print("load grid")
    identity_grid = checkpoint["identity_grid"].to(opt.device)
    noise_grid = checkpoint["noise_grid"].to(opt.device)
    print(checkpoint["best_clean_acc"], checkpoint["best_bd_acc"])

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # List the classifier's direct sub-modules.
    for child_name, _child in netC._modules.items():
        print(child_name)

    # Collect every layer3 output produced while scanning the test set.
    captured = []

    def _record_output(_module, _inputs, output):
        captured.append(output)

    hook = netC.layer3.register_forward_hook(_record_output)

    print("Forwarding all the validation dataset:")
    for batch_idx, (inputs, _) in enumerate(test_dl):
        netC(inputs.to(opt.device))
        progress_bar(batch_idx, len(test_dl))

    # Mean activation per channel; argsort gives channels in ascending order.
    stacked = torch.cat(captured, dim=0)
    mean_activation = stacked.mean(dim=(0, 2, 3))
    prune_order = torch.argsort(mean_activation)
    total_channels = prune_order.shape[0]
    pruning_mask = torch.ones(total_channels, dtype=torch.bool)
    hook.remove()

    # Pruning sweep - the network is NOT fine-tuned after a channel is removed.
    results_path = "mnist_{}_results.txt".format(opt.attack_mode)
    with open(results_path, "w") as outs:
        for num_pruned in range(total_channels):
            net_pruned = copy.deepcopy(netC)
            if num_pruned > 0:
                # Drop the next channel in ascending-activation order.
                pruning_mask[prune_order[num_pruned - 1]] = False
            print("Pruned {} filters".format(num_pruned))

            kept = total_channels - num_pruned
            net_pruned.layer3.conv1 = nn.Conv2d(
                total_channels, kept, (3, 3), stride=2, padding=1, bias=False)
            net_pruned.linear6 = nn.Linear(kept * 16, 512)

            # Copy the surviving weights from the unpruned network.
            for child_name, child in net_pruned._modules.items():
                if "layer3" in child_name:
                    child.conv1.weight.data = (
                        netC.layer3.conv1.weight.data[pruning_mask])
                    child.ind = pruning_mask
                elif child_name == "linear6":
                    full_weight = netC.linear6.weight.data.reshape(-1, 64, 16)
                    child.weight.data = (
                        full_weight[:, pruning_mask].reshape(512, -1))
                    child.bias.data = netC.linear6.bias.data

            net_pruned.to(opt.device)
            clean, bd = eval(
                net_pruned, identity_grid, noise_grid, test_dl, opt)
            outs.write("%d %0.4f %0.4f\n" % (num_pruned, clean, bd))
def main():
    """Evaluate the trained networks on plain and pre-processed test data.

    After restoring ``netC``, ``netG`` and ``netM`` from the checkpoint, the
    project's ``eval`` routine is run on the unmodified test loader, then on
    loaders built with ``k`` in (3, 5) ("Smoothing") and with ``c`` in
    (1, 2, 3) ("Color-depth shrinking").

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    opt = get_arguments().parse_args()

    class_counts = {"mnist": 10, "cifar10": 10, "gtsrb": 43}
    if opt.dataset not in class_counts:
        raise Exception("Invalid Dataset")
    opt.num_classes = class_counts[opt.dataset]

    # mnist is 28x28 grayscale; cifar10 and gtsrb are 32x32 RGB.
    if opt.dataset == "mnist":
        side, channels = 28, 1
    elif opt.dataset in ("cifar10", "gtsrb"):
        side, channels = 32, 3
    else:
        raise Exception("Invalid Dataset")
    opt.input_height = side
    opt.input_width = side
    opt.input_channel = channels

    # Load models and masks
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    elif opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43).to(opt.device)
    elif opt.dataset == "cifar10":
        netC = PreActResNet18().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    path_model = os.path.join(
        opt.checkpoints,
        opt.dataset,
        opt.attack_mode,
        "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset),
    )
    checkpoint = torch.load(path_model)

    def _restore(net, key):
        # Load weights, move to the target device and freeze for inference.
        net.load_state_dict(checkpoint[key])
        net.to(opt.device)
        net.eval()
        net.requires_grad_(False)
        return net

    print("load C")
    _restore(netC, "netC")
    print("load G")
    netG = _restore(Generator(opt), "netG")
    netM = _restore(Generator(opt, out_channels=1), "netM")

    # Baseline on the untouched test set.
    test_dl = get_dataloader(opt, train=False)
    print("Original")
    eval(netC, netG, netM, test_dl, opt)

    print("Smoothing")
    for k in (3, 5):
        print("k = ", k)
        smoothed_dl = get_dataloader(opt, train=False, k=k)
        eval(netC, netG, netM, smoothed_dl, opt)

    print("Color-depth shrinking")
    for c in range(1, 4):
        print("c = ", c)
        shrunk_dl = get_dataloader(opt, train=False, c=c)
        eval(netC, netG, netM, shrunk_dl, opt)
def main():
    """Prune ``conv5`` channels of the MNIST classifier one at a time.

    Restores ``netC``, ``netG`` and ``netM`` from the checkpoint, records the
    ``relu6`` outputs over the test set, orders the channels by ascending mean
    activation and then, for every pruning level, rebuilds a copy of the
    classifier without the pruned channels (no fine-tuning) and evaluates it.
    One ``"<index> <clean> <bd>"`` line per level is written to
    ``opt.outfile``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist".
    """
    # Prepare arguments
    opt = get_arguments().parse_args()
    if opt.dataset == "mnist":
        opt.num_classes = 10
    else:
        raise Exception("Invalid Dataset")
    if opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
    else:
        raise Exception("Invalid Dataset")

    # Load models
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    path_model = os.path.join(
        opt.checkpoints,
        opt.dataset,
        opt.attack_mode,
        "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset),
    )
    state_dict = torch.load(path_model)
    netC.load_state_dict(state_dict["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)

    netG = Generator(opt)
    netG.load_state_dict(state_dict["netG"])
    netG.to(opt.device)
    netG.eval()
    netG.requires_grad_(False)

    netM = Generator(opt, out_channels=1)
    netM.load_state_dict(state_dict["netM"])
    netM.to(opt.device)
    netM.eval()
    netM.requires_grad_(False)

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # Forward hook for getting layer's output
    container = []

    def forward_hook(module, input, output):
        container.append(output)

    hook = netC.relu6.register_forward_hook(forward_hook)

    # Forwarding all the validation set
    print("Forwarding all the validation dataset:")
    for batch_idx, (inputs, _) in enumerate(test_dl):
        inputs = inputs.to(opt.device)
        netC(inputs)
        progress_bar(batch_idx, len(test_dl))

    # Mean activation per channel; argsort returns channel indices in
    # ascending order, so seq_sort[0] is the least-activated channel.
    container = torch.cat(container, dim=0)
    activation = torch.mean(container, dim=[0, 2, 3])
    seq_sort = torch.argsort(activation)
    pruning_mask = torch.ones(seq_sort.shape[0], dtype=bool)
    hook.remove()

    # Pruning times - no-tuning after pruning a channel!!!
    with open(opt.outfile, "w") as outs:
        for index in range(pruning_mask.shape[0]):
            net_pruned = copy.deepcopy(netC)
            num_pruned = index
            if index:
                # Level `index` removes the `index` least-activated channels,
                # i.e. seq_sort[0 .. index-1]. Indexing with `index` here
                # would skip seq_sort[0] and never prune the weakest channel.
                channel = seq_sort[index - 1]
                pruning_mask[channel] = False
            print("Pruned {} filters".format(num_pruned))

            net_pruned.conv5 = nn.Conv2d(64, 64 - num_pruned, (5, 5), 1, 0)
            net_pruned.linear6 = nn.Linear(16 * (64 - num_pruned), 512)

            # Re-assigning weight to the pruned net
            for name, module in net_pruned._modules.items():
                if "conv5" in name:
                    module.weight.data = netC.conv5.weight.data[pruning_mask]
                    module.bias.data = netC.conv5.bias.data[pruning_mask]
                elif "linear6" in name:
                    # Drop the input columns fed by the pruned channels.
                    module.weight.data = netC.linear6.weight.data.reshape(
                        -1, 64, 16)[:, pruning_mask].reshape(512, -1)
                    module.bias.data = netC.linear6.bias.data
                else:
                    continue
            clean, bd = eval(net_pruned, netG, netM, test_dl, opt)
            outs.write("%d %0.4f %0.4f\n" % (index, clean, bd))
def train(opt):
    """Train the mask network, then the classifier and generator.

    Builds ``netC``, ``netG`` and ``netM`` with their optimizers and
    schedulers, resumes from the checkpoint under
    ``opt.checkpoints/opt.dataset/opt.attack_mode`` when one exists (in which
    case ``opt`` is replaced by the stored options), otherwise starts from
    epoch 1 with a freshly reset tensorboard log directory. When starting at
    epoch 1 the mask network is trained for 25 rounds first; it is then frozen
    and the classifier/generator loop runs for up to ``opt.n_iters`` rounds.

    Raises:
        Exception: if ``opt.dataset`` is not "cifar10", "gtsrb" or "mnist".
    """
    # Classifier architecture depends on the dataset.
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    elif opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43).to(opt.device)
    elif opt.dataset == "cifar10":
        netC = PreActResNet18().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    netG = Generator(opt).to(opt.device)
    optimizerC = torch.optim.SGD(
        netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)
    optimizerG = torch.optim.Adam(
        netG.parameters(), opt.lr_G, betas=(0.5, 0.9))
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizerG, opt.schedulerG_milestones, opt.schedulerG_lambda)

    netM = Generator(opt, out_channels=1).to(opt.device)
    optimizerM = torch.optim.Adam(
        netM.parameters(), opt.lr_M, betas=(0.5, 0.9))
    schedulerM = torch.optim.lr_scheduler.MultiStepLR(
        optimizerM, opt.schedulerM_milestones, opt.schedulerM_lambda)

    # Tensorboard output lives in <run_dir>/log_dir; makedirs also creates
    # run_dir itself when it is missing.
    run_dir = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    log_dir = os.path.join(run_dir, "log_dir")
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    tf_writer = SummaryWriter(log_dir=log_dir)

    # Resume from an existing checkpoint, or start over.
    ckpt_path = os.path.join(
        run_dir, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path)
        for net, key in ((netC, "netC"), (netG, "netG"), (netM, "netM")):
            net.load_state_dict(checkpoint[key])
        epoch = checkpoint["epoch"] + 1
        for stateful, key in (
            (optimizerC, "optimizerC"),
            (optimizerG, "optimizerG"),
            (schedulerC, "schedulerC"),
            (schedulerG, "schedulerG"),
        ):
            stateful.load_state_dict(checkpoint[key])
        best_acc_clean = checkpoint["best_acc_clean"]
        best_acc_bd = checkpoint["best_acc_bd"]
        best_acc_cross = checkpoint["best_acc_cross"]
        # The stored options take over from the ones passed in.
        opt = checkpoint["opt"]
        print("Continue training")
    else:
        best_acc_clean = best_acc_bd = best_acc_cross = 0.0
        epoch = 1

        # Reset tensorboard
        shutil.rmtree(log_dir)
        os.makedirs(log_dir)
        print("Training from scratch")

    # Two loaders per split.
    train_dl1 = get_dataloader(opt, train=True)
    train_dl2 = get_dataloader(opt, train=True)
    test_dl1 = get_dataloader(opt, train=False)
    test_dl2 = get_dataloader(opt, train=False)

    # Mask warm-up, only on a fresh run.
    if epoch == 1:
        netM.train()
        for _ in range(25):
            print(
                "Epoch {} - {} - {} | mask_density: {} - lambda_div: {} - lambda_norm: {}:"
                .format(
                    epoch,
                    opt.dataset,
                    opt.attack_mode,
                    opt.mask_density,
                    opt.lambda_div,
                    opt.lambda_norm,
                ))
            train_mask_step(
                netM, optimizerM, schedulerM, train_dl1, train_dl2, epoch,
                opt, tf_writer)
            epoch = eval_mask(
                netM, optimizerM, schedulerM, test_dl1, test_dl2, epoch, opt)
            epoch += 1

    # The mask network stays fixed while the classifier/generator train.
    netM.eval()
    netM.requires_grad_(False)

    for _ in range(opt.n_iters):
        print("Epoch {} - {} - {} | mask_density: {} - lambda_div: {}:".format(
            epoch, opt.dataset, opt.attack_mode, opt.mask_density,
            opt.lambda_div))
        train_step(
            netC,
            netG,
            netM,
            optimizerC,
            optimizerG,
            schedulerC,
            schedulerG,
            train_dl1,
            train_dl2,
            epoch,
            opt,
            tf_writer,
        )
        best_acc_clean, best_acc_bd, best_acc_cross, epoch = eval(
            netC,
            netG,
            netM,
            optimizerC,
            optimizerG,
            schedulerC,
            schedulerG,
            test_dl1,
            test_dl2,
            epoch,
            best_acc_clean,
            best_acc_bd,
            best_acc_cross,
            opt,
        )
        epoch += 1
        if epoch > opt.n_iters:
            break
def strip(opt, mode="clean"):
    """Run the STRIP detector and collect per-sample entropies.

    Args:
        opt: options object; ``ckpt_folder``, ``ckpt_path``, ``log_dir`` and
            ``bs`` are overwritten on it as a side effect.
        mode: "attack" to score backdoored inputs and clean inputs, anything
            else to score clean inputs only. This argument is honoured as
            passed: it is no longer overwritten by ``opt.attack_mode``, which
            previously made the "attack" branch unreachable for ordinary
            attack-mode values.

    Returns:
        ``(list_entropy_trojan, list_entropy_benign)``; the first list is
        empty unless ``mode == "attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not a supported dataset.
    """
    # Prepare pretrained classifier
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    elif opt.dataset == "cifar10" or opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == "celeba":
        netC = ResNet18().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    # Load pretrained model. The attack mode only selects the checkpoint file;
    # it is kept separate from `mode`, which selects the test procedure.
    attack_mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder,
        "{}_{}_morph.pth.tar".format(opt.dataset, attack_mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    state_dict = torch.load(opt.ckpt_path)
    netC.load_state_dict(state_dict["netC"])
    if mode != "clean":
        # Keep the grids on the same device as the inputs they are applied to.
        identity_grid = state_dict["identity_grid"].to(opt.device)
        noise_grid = state_dict["noise_grid"].to(opt.device)
    netC.requires_grad_(False)
    netC.eval()
    netC.to(opt.device)

    # Prepare test set; a single batch holds all n_test samples.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)
    denormalizer = Denormalizer(opt)

    # STRIP detector
    strip_detector = STRIP(opt)

    # Entropy list
    list_entropy_trojan = []
    list_entropy_benign = []

    if mode == "attack":
        # Testing with perturbed data
        print("Testing with bd data !!!!")
        inputs, _ = next(iter(test_dataloader))
        inputs = inputs.to(opt.device)
        bd_inputs = create_backdoor(inputs, identity_grid, noise_grid, opt)

        # Convert to uint8 images with the channel axis moved last.
        bd_inputs = denormalizer(bd_inputs) * 255.0
        bd_inputs = bd_inputs.detach().cpu().numpy()
        bd_inputs = np.clip(bd_inputs, 0, 255).astype(np.uint8).transpose(
            (0, 2, 3, 1))
        for index in range(opt.n_test):
            background = bd_inputs[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_trojan.append(entropy)
            progress_bar(index, opt.n_test)

        # Testing with clean data
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
    else:
        # Testing with clean data
        print("Testing with clean data !!!!")
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
            progress_bar(index, opt.n_test)

    return list_entropy_trojan, list_entropy_benign
def strip(opt, mode="clean"):
    """Run the STRIP detector and collect per-sample entropies.

    Args:
        opt: options object; the input-size fields and ``bs`` are overwritten
            on it as a side effect.
        mode: "attack" to score generator-backdoored inputs and clean inputs;
            otherwise only clean inputs are scored. The generator and mask
            networks are loaded whenever ``mode`` is not "clean".

    Returns:
        ``(list_entropy_trojan, list_entropy_benign)``; the first list is
        empty unless ``mode == "attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    # dataset -> (input_height, input_width, input_channel)
    input_shapes = {
        "mnist": (28, 28, 1),
        "cifar10": (32, 32, 3),
        "gtsrb": (32, 32, 3),
    }
    if opt.dataset not in input_shapes:
        raise Exception("Invalid dataset")
    (opt.input_height,
     opt.input_width,
     opt.input_channel) = input_shapes[opt.dataset]

    # Prepare pretrained classifier
    if opt.dataset == "mnist":
        netC = NetC_MNIST()
    elif opt.dataset == "cifar10":
        netC = PreActResNet18()
    else:
        netC = PreActResNet18(num_classes=43)
    netC.requires_grad_(False)
    netC.to(opt.device)
    netC.eval()

    uses_backdoor = mode != "clean"
    if uses_backdoor:
        netG = Generator(opt)
        netG.requires_grad_(False)
        netG.to(opt.device)
        netG.eval()

    # Load pretrained model
    ckpt_path = os.path.join(
        opt.checkpoints,
        opt.dataset,
        opt.attack_mode,
        "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset),
    )
    checkpoint = torch.load(ckpt_path)
    netC.load_state_dict(checkpoint["netC"])
    if uses_backdoor:
        netG.load_state_dict(checkpoint["netG"])
        netM = Generator(opt, out_channels=1)
        netM.load_state_dict(checkpoint["netM"])
        netM.to(opt.device)
        netM.eval()
        netM.requires_grad_(False)

    # Prepare test set; a single batch holds all n_test samples.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)

    # STRIP detector
    strip_detector = STRIP(opt)

    def _entropy_of(sample):
        return strip_detector(sample, testset, netC)

    list_entropy_trojan = []
    list_entropy_benign = []

    if mode != "attack":
        # Testing with clean data
        print("Testing with clean data !!!!")
        for index in range(opt.n_test):
            clean_sample, _ = testset[index]
            list_entropy_benign.append(_entropy_of(clean_sample))
            progress_bar(index, opt.n_test)
        return list_entropy_trojan, list_entropy_benign

    # Testing with perturbed data
    print("Testing with bd data !!!!")
    inputs, targets = next(iter(test_dataloader))
    inputs = inputs.to(opt.device)

    # Blend the generated pattern into the input where the mask is active.
    patterns = netG.normalize_pattern(netG(inputs))
    batch_masks = netM.threshold(netM(inputs))
    blended = inputs + (patterns - inputs) * batch_masks

    # Convert to uint8 images with the channel axis moved last.
    blended = netG.denormalize_pattern(blended) * 255.0
    blended = blended.detach().cpu().numpy()
    bd_images = np.clip(blended, 0, 255).astype(np.uint8).transpose(
        (0, 2, 3, 1))

    for index in range(opt.n_test):
        list_entropy_trojan.append(_entropy_of(bd_images[index]))
        progress_bar(index, opt.n_test)

    # Testing with clean data
    for index in range(opt.n_test):
        clean_sample, _ = testset[index]
        list_entropy_benign.append(_entropy_of(clean_sample))

    return list_entropy_trojan, list_entropy_benign