def _get_classifier(self, opt): if opt.dataset == "mnist": classifier = NetC_MNIST() elif opt.dataset == "cifar10": classifier = PreActResNet18() elif opt.dataset == "gtsrb": classifier = PreActResNet18(num_classes=43) else: raise Exception("Invalid Dataset") # Load pretrained classifier ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode) if not os.path.exists(ckpt_folder): os.makedirs(ckpt_folder) ckpt_path = os.path.join(ckpt_folder, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset)) state_dict = torch.load(ckpt_path) classifier.load_state_dict(state_dict["netC"]) for param in classifier.parameters(): param.requires_grad = False classifier.eval() return classifier.to(opt.device)
def get_model(opt):
    """Build the classifier for ``opt.dataset`` together with its optimizer and scheduler.

    Args:
        opt: options namespace; reads ``dataset``, ``device``, ``lr_C``,
            ``schedulerC_milestones``, ``schedulerC_lambda`` and, for "cifar10"/"gtsrb",
            ``num_classes``.

    Returns:
        ``(netC, optimizerC, schedulerC)``: the classifier on ``opt.device``, an SGD
        optimizer (momentum 0.9, weight decay 5e-4) and a ``MultiStepLR`` scheduler.

    Raises:
        Exception: if ``opt.dataset`` is not "cifar10", "gtsrb", "celeba" or "mnist".
    """
    if opt.dataset in ("cifar10", "gtsrb"):
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == "celeba":
        netC = ResNet18().to(opt.device)
    elif opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        # Fail explicitly rather than crashing later on ``None.parameters()``.
        raise Exception("Invalid dataset")

    # Optimizer
    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)
    # Scheduler
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    return netC, optimizerC, schedulerC
def train(opt):
    """Train the mask generator, then jointly train the classifier and pattern generator.

    On a fresh run (no checkpoint), the mask generator ``netM`` is trained for 25 epochs
    first. After that, ``netM`` is frozen and ``train_step``/``eval`` are repeated for the
    classifier ``netC`` and pattern generator ``netG`` until ``epoch`` exceeds
    ``opt.n_iters``. If ``<checkpoints>/<dataset>/<attack_mode>/<attack_mode>_<dataset>_ckpt.pth.tar``
    exists, training resumes from it and ``opt`` is replaced by the options stored there.

    Args:
        opt: options namespace (dataset, device, learning rates, scheduler settings,
            checkpoint location, ``n_iters``, ...).

    Raises:
        Exception: if ``opt.dataset`` is not "cifar10", "gtsrb" or "mnist".
    """
    # Prepare model related things
    if opt.dataset == "cifar10":
        netC = PreActResNet18().to(opt.device)
    elif opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=43).to(opt.device)
    elif opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    netG = Generator(opt).to(opt.device)
    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)
    optimizerG = torch.optim.Adam(netG.parameters(), opt.lr_G, betas=(0.5, 0.9))
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizerG, opt.schedulerG_milestones, opt.schedulerG_lambda)

    netM = Generator(opt, out_channels=1).to(opt.device)
    optimizerM = torch.optim.Adam(netM.parameters(), opt.lr_M, betas=(0.5, 0.9))
    schedulerM = torch.optim.lr_scheduler.MultiStepLR(
        optimizerM, opt.schedulerM_milestones, opt.schedulerM_lambda)

    # Checkpoint and tensorboard logs both live under <checkpoints>/<dataset>/<attack_mode>.
    ckpt_folder = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    ckpt_path = os.path.join(
        ckpt_folder, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    log_dir = os.path.join(ckpt_folder, "log_dir")
    os.makedirs(log_dir, exist_ok=True)

    # Continue training ?
    if os.path.exists(ckpt_path):
        # map_location lets a checkpoint saved on a GPU be restored on another device.
        state_dict = torch.load(ckpt_path, map_location=opt.device)
        netC.load_state_dict(state_dict["netC"])
        netG.load_state_dict(state_dict["netG"])
        netM.load_state_dict(state_dict["netM"])
        epoch = state_dict["epoch"] + 1
        optimizerC.load_state_dict(state_dict["optimizerC"])
        optimizerG.load_state_dict(state_dict["optimizerG"])
        schedulerC.load_state_dict(state_dict["schedulerC"])
        schedulerG.load_state_dict(state_dict["schedulerG"])
        best_acc_clean = state_dict["best_acc_clean"]
        best_acc_bd = state_dict["best_acc_bd"]
        best_acc_cross = state_dict["best_acc_cross"]
        opt = state_dict["opt"]
        print("Continue training")
    else:
        # Prepare mask
        best_acc_clean = 0.0
        best_acc_bd = 0.0
        best_acc_cross = 0.0
        epoch = 1
        # Reset tensorboard
        shutil.rmtree(log_dir)
        os.makedirs(log_dir)
        print("Training from scratch")

    # For tensorboard. Created only after the optional reset above so the rmtree
    # cannot delete the event file of a live writer.
    tf_writer = SummaryWriter(log_dir=log_dir)

    # Prepare dataset
    train_dl1 = get_dataloader(opt, train=True)
    train_dl2 = get_dataloader(opt, train=True)
    test_dl1 = get_dataloader(opt, train=False)
    test_dl2 = get_dataloader(opt, train=False)

    # Stage 1 (fresh runs only): 25 epochs of mask-generator training.
    if epoch == 1:
        netM.train()
        for _ in range(25):
            print(
                "Epoch {} - {} - {} | mask_density: {} - lambda_div: {} - lambda_norm: {}:"
                .format(epoch, opt.dataset, opt.attack_mode, opt.mask_density,
                        opt.lambda_div, opt.lambda_norm))
            train_mask_step(netM, optimizerM, schedulerM, train_dl1, train_dl2, epoch, opt,
                            tf_writer)
            epoch = eval_mask(netM, optimizerM, schedulerM, test_dl1, test_dl2, epoch, opt)
            epoch += 1
    netM.eval()
    netM.requires_grad_(False)

    # Stage 2: classifier + pattern generator. The epoch counter continues from stage 1.
    for _ in range(opt.n_iters):
        # Checked before training so a resumed run that is already past the budget
        # does not train one extra epoch.
        if epoch > opt.n_iters:
            break
        print("Epoch {} - {} - {} | mask_density: {} - lambda_div: {}:".format(
            epoch, opt.dataset, opt.attack_mode, opt.mask_density, opt.lambda_div))
        train_step(
            netC, netG, netM, optimizerC, optimizerG, schedulerC, schedulerG,
            train_dl1, train_dl2, epoch, opt, tf_writer,
        )
        best_acc_clean, best_acc_bd, best_acc_cross, epoch = eval(
            netC, netG, netM, optimizerC, optimizerG, schedulerC, schedulerG,
            test_dl1, test_dl2, epoch, best_acc_clean, best_acc_bd, best_acc_cross, opt,
        )
        epoch += 1
def _strip_clean_entropies(strip_detector, testset, netC, n_test, show_progress):
    """Return STRIP entropies for the first ``n_test`` clean samples of ``testset``."""
    entropies = []
    for index in range(n_test):
        background, _ = testset[index]
        entropies.append(strip_detector(background, testset, netC))
        if show_progress:
            progress_bar(index, n_test)
    return entropies


def strip(opt, mode="clean"):
    """Compute STRIP entropies for the trained classifier on clean and/or backdoored inputs.

    Sets ``opt.input_height/width/channel`` for the dataset and overwrites ``opt.bs`` with
    ``opt.n_test`` before building the test dataloader.

    Args:
        opt: options namespace; reads ``dataset``, ``device``, ``checkpoints``,
            ``attack_mode`` and ``n_test``.
        mode: "attack" evaluates backdoored versions of the first test batch and then
            clean samples. Any other value evaluates clean samples only. The generator
            ``netG`` is loaded whenever ``mode != "clean"``.

    Returns:
        ``(list_entropy_trojan, list_entropy_benign)``. The trojan list stays empty
        unless ``mode == "attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not "mnist", "cifar10" or "gtsrb".
    """
    if opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
    elif opt.dataset in ("cifar10", "gtsrb"):
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    else:
        raise Exception("Invalid dataset")

    # Prepare pretrained classifier
    if opt.dataset == "mnist":
        netC = NetC_MNIST()
    elif opt.dataset == "cifar10":
        netC = PreActResNet18()
    else:
        netC = PreActResNet18(num_classes=43)
    netC.requires_grad_(False)
    netC.to(opt.device)
    netC.eval()

    netG = None
    if mode != "clean":
        netG = Generator(opt)
        netG.requires_grad_(False)
        netG.to(opt.device)
        netG.eval()

    # Load pretrained model
    ckpt_dir = os.path.join(opt.checkpoints, opt.dataset, opt.attack_mode)
    ckpt_path = os.path.join(
        ckpt_dir, "{}_{}_ckpt.pth.tar".format(opt.attack_mode, opt.dataset))
    # map_location lets a checkpoint saved on a GPU be evaluated on another device.
    state_dict = torch.load(ckpt_path, map_location=opt.device)
    netC.load_state_dict(state_dict["netC"])
    if netG is not None:
        netG.load_state_dict(state_dict["netG"])
    netM = Generator(opt, out_channels=1)
    netM.load_state_dict(state_dict["netM"])
    netM.to(opt.device)
    netM.eval()
    netM.requires_grad_(False)

    # Prepare test set
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)

    # STRIP detector
    strip_detector = STRIP(opt)

    list_entropy_trojan = []
    if mode == "attack":
        # Testing with perturbed data
        print("Testing with bd data !!!!")
        inputs, _ = next(iter(test_dataloader))
        inputs = inputs.to(opt.device)
        patterns = netG.normalize_pattern(netG(inputs))
        batch_masks = netM.threshold(netM(inputs))
        # Blend generated patterns into the inputs where the mask is active.
        bd_inputs = inputs + (patterns - inputs) * batch_masks
        bd_inputs = netG.denormalize_pattern(bd_inputs) * 255.0
        bd_inputs = bd_inputs.detach().cpu().numpy()
        # NCHW float -> NHWC uint8 images for the STRIP detector.
        bd_inputs = np.clip(bd_inputs, 0, 255).astype(np.uint8).transpose((0, 2, 3, 1))
        for index in range(opt.n_test):
            entropy = strip_detector(bd_inputs[index], testset, netC)
            list_entropy_trojan.append(entropy)
            progress_bar(index, opt.n_test)

        # Testing with clean data
        list_entropy_benign = _strip_clean_entropies(
            strip_detector, testset, netC, opt.n_test, show_progress=False)
    else:
        # Testing with clean data
        print("Testing with clean data !!!!")
        list_entropy_benign = _strip_clean_entropies(
            strip_detector, testset, netC, opt.n_test, show_progress=True)

    return list_entropy_trojan, list_entropy_benign