def train(opt, init_mask, init_pattern):
    """Fit a ``RegressionModel`` seeded with ``init_mask`` and ``init_pattern``.

    ``train_step`` is called once per epoch, for at most ``opt.epoch`` epochs,
    on the loader returned by ``get_dataloader(opt, train=False)``.  The loop
    stops as soon as ``train_step`` returns a truthy value.  Afterwards the
    recorder is asked to save its result.

    Returns:
        The tuple ``(recorder, opt)``.
    """
    # Note: the loader is requested with train=False.
    loader = get_dataloader(opt, train=False)

    # Build regression model
    model = RegressionModel(opt, init_mask, init_pattern)
    model = model.to(opt.device)

    # Set optimizer
    adam = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.9))

    # Set recorder (for recording best result)
    recorder = Recorder(opt)

    for epoch in range(opt.epoch):
        # A truthy return value from train_step requests early stopping.
        if train_step(model, adam, loader, recorder, epoch, opt):
            break

    # Save result to dir
    recorder.save_result_to_dir(opt)

    return recorder, opt
def get_model(args):
    """Train, or load and evaluate, the CKD classifier for ``args.Superclasses``.

    A teacher (depth 40, widen factor 4) and a student extractor
    (``student_EX``, depth 16) are restored from fixed checkpoint paths and put
    in eval mode.  When ``args.model_pretrained is True`` the student
    classifier (``student_CL``) is restored too, evaluated once, and the
    function returns.  Otherwise the parameters of both student parts are
    optimised with SGD for ``args.model_epochs`` epochs, and the classifier
    weights are saved every time the test accuracy improves.

    Args:
        args: Option namespace.  Reads ``Superclasses``, ``device``,
            ``model_pretrained``, ``lr``, ``weight_decay``, ``scheduler`` and
            ``model_epochs``.

    Returns:
        None.  The accuracy is printed and checkpoints are written under
        ``DB_pretrained/CKD``.
    """
    os.makedirs('DB_pretrained/CKD', exist_ok=True)

    total_idx = total_combine(args.Superclasses)
    train_loader, test_loader = get_dataloader(
        args, train_subidx=total_idx, test_subidx=total_idx)

    teacher = network.wresnet.wideresnet(
        depth=40, num_classes=100, widen_factor=4, dropRate=0.0)
    student_EX = network.wresnet.wideresnet_ex(
        depth=16, num_classes=100, widen_factor=1, dropRate=0.0)
    student_CL = network.wresnet.wideresnet_cl(
        depth=16, num_classes=len(total_idx), EX_widen_factor=1,
        widen_factor=(0.25 * len(args.Superclasses)), dropRate=0.0)

    Oracle_path = './DB_pretrained/Oracle/Oracle_cifar100.pt'
    EX_path = './DB_pretrained/Library/library_cifar100.pt'
    CL_path = './DB_pretrained/CKD/CKD_CL_%s.pt' % args.Superclasses

    teacher.load_state_dict(torch.load(Oracle_path))
    student_EX.load_state_dict(torch.load(EX_path))
    teacher, student_EX = teacher.to(args.device), student_EX.to(args.device)
    teacher, student_EX = teacher.eval(), student_EX.eval()

    if args.model_pretrained is True:
        # Evaluation-only path: restore the classifier and report accuracy.
        student_CL.load_state_dict(torch.load(CL_path))
        student_CL = student_CL.to(args.device)
        student_CL = student_CL.eval()
        student = [student_EX, student_CL]
        best_acc = test(student, args.device, test_loader, 0, True)
        print("\nModel for %s Acc=%.2f%%" % (args.Superclasses, best_acc * 100))
        return

    student_CL = student_CL.to(args.device)
    optimizer_S = optim.SGD(
        list(student_EX.parameters()) + list(student_CL.parameters()),
        lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    if args.scheduler:
        scheduler_S = optim.lr_scheduler.MultiStepLR(optimizer_S, [40, 80], 0.1)

    best_acc = 0
    student = [student_EX, student_CL]
    for epoch in range(1, args.model_epochs + 1):
        train(args, teacher=teacher, student=student, device=args.device,
              train_loader=train_loader, optimizer=optimizer_S, epoch=epoch,
              total_idx=total_idx)
        if args.scheduler:
            # Step the schedule after the epoch's optimizer updates
            # (presumably done inside train()).  Since PyTorch 1.1, stepping
            # first skips the first value of the learning-rate schedule.
            scheduler_S.step()
        acc = test(student, args.device, test_loader, epoch)
        if acc > best_acc:
            # Keep only the best-performing classifier on disk.
            best_acc = acc
            torch.save(student_CL.state_dict(), CL_path)
    print("\nModel for %s Acc=%.2f%%" % (args.Superclasses, best_acc * 100))
def get_experts(args, Oracle, library):
    """Train, or load and evaluate, one expert per ``CIFAR100_Superclass`` key.

    For every primitive task an expert classifier (depth 16, widen factor
    0.25) is built.  When ``args.Experts_pretrained is True`` the expert is
    restored from its checkpoint (tasks without a checkpoint file are skipped)
    and evaluated together with ``library``.  Otherwise only the expert's
    parameters are optimised with SGD for ``args.Experts_epochs`` epochs, and
    its weights are saved whenever the test accuracy improves.

    Args:
        args: Option namespace.  Reads ``device``, ``Experts_pretrained``,
            ``lr``, ``weight_decay``, ``scheduler`` and ``Experts_epochs``.
        Oracle: Model passed to ``train`` as ``Oracle``; set to eval mode.
        library: Model paired with each expert as ``[library, expert]``; set
            to eval mode.

    Returns:
        None.  Accuracies are printed and checkpoints are written under
        ``DB_Pool of Experts/Experts``.
    """
    os.makedirs('DB_Pool of Experts/Experts', exist_ok=True)

    for primitiveTask in CIFAR100_Superclass.keys():
        priTask_idx = CIFAR100_Superclass[primitiveTask]
        train_loader, test_loader = get_dataloader(args, test_subidx=priTask_idx)
        expert = network.wresnet.wideresnet_cl(
            depth=16, num_classes=len(priTask_idx), EX_widen_factor=1,
            widen_factor=0.25, dropRate=0.0)
        priTask_path = './DB_Pool of Experts/Experts/expert_%s.pt' % primitiveTask

        if args.Experts_pretrained is True:
            # Evaluation-only path; silently skip tasks with no checkpoint.
            if not os.path.exists(priTask_path):
                continue
            expert.load_state_dict(torch.load(priTask_path))
            expert = expert.to(args.device)
            library.eval()
            expert.eval()
            student = [library, expert]
            best_acc = test(student, args.device, test_loader, 0, True)
            print("\nModel for %s Acc=%.2f%%" % (primitiveTask, best_acc * 100))
            continue

        expert = expert.to(args.device)
        Oracle.eval()
        library.eval()
        # Only the expert's parameters are handed to the optimizer.
        optimizer_S = optim.SGD(expert.parameters(), lr=args.lr,
                                weight_decay=args.weight_decay, momentum=0.9)
        if args.scheduler:
            scheduler_S = optim.lr_scheduler.MultiStepLR(optimizer_S, [40, 80], 0.1)

        best_acc = 0
        student = [library, expert]
        for epoch in range(1, args.Experts_epochs + 1):
            train(args, Oracle=Oracle, student=student, device=args.device,
                  train_loader=train_loader, optimizer=optimizer_S, epoch=epoch,
                  priTask_idx=priTask_idx)
            if args.scheduler:
                # Step the schedule after the epoch's optimizer updates
                # (presumably done inside train()).  Since PyTorch 1.1,
                # stepping first skips the first value of the LR schedule.
                scheduler_S.step()
            acc = test(student, args.device, test_loader, epoch)
            if acc > best_acc:
                best_acc = acc
                torch.save(expert.state_dict(), priTask_path)
        print("\nModel for %s Acc=%.2f%%" % (primitiveTask, best_acc * 100))
def get_Oracle(args):
    """Return the Oracle network, either restored from disk or freshly trained.

    When ``args.Oracle_pretrained is True`` the weights are loaded from
    ``./DB_pretrained/Oracle/Oracle_cifar100.pt``, the model is evaluated once
    and returned in eval mode.  Otherwise the model is trained with SGD for
    ``args.Oracle_epochs`` epochs, saving to ``args.ckpt`` whenever the test
    accuracy improves, and the weights are finally reloaded from the fixed
    Oracle checkpoint path before returning.

    Args:
        args: Option namespace.  Reads ``Oracle_classes``,
            ``Oracle_pretrained``, ``device``, ``lr``, ``weight_decay``,
            ``momentum``, ``scheduler``, ``Oracle_step_size``,
            ``Oracle_epochs`` and ``ckpt``.

    Returns:
        The Oracle model in eval mode.
    """
    os.makedirs('DB_pretrained/Oracle', exist_ok=True)
    train_loader, test_loader = get_dataloader(args)
    model = network.wresnet.wideresnet(
        depth=40, num_classes=args.Oracle_classes, widen_factor=4, dropRate=0.0)

    if args.Oracle_pretrained is True:
        model.load_state_dict(
            torch.load('./DB_pretrained/Oracle/Oracle_cifar100.pt'))
        model = model.to(args.device)
        model.eval()
        best_acc = test(args, model, args.device, test_loader, 0, True)
        print("\nOracle Acc for all classes=%.2f%%" % (best_acc * 100))
        return model

    model = model.to(args.device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          weight_decay=args.weight_decay,
                          momentum=args.momentum)
    best_acc = 0
    if args.scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer, args.Oracle_step_size, 0.1)

    for epoch in range(1, args.Oracle_epochs + 1):
        train(args, model, args.device, train_loader, optimizer, epoch)
        if args.scheduler:
            # Step the schedule after the epoch's optimizer updates
            # (presumably done inside train()).  Since PyTorch 1.1, stepping
            # first skips the first value of the learning-rate schedule.
            scheduler.step()
        acc = test(args, model, args.device, test_loader, epoch)
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), args.ckpt)
    print("\nOracle Acc for all classes = %.2f%%" % (best_acc * 100))

    # NOTE(review): the best weights are saved to args.ckpt but reloaded from
    # a hard-coded path; this only works if both point at the same file.
    # Confirm against the default of args.ckpt.
    model.load_state_dict(
        torch.load('./DB_pretrained/Oracle/Oracle_cifar100.pt'))
    model.eval()
    return model
def get_MQ(args):
    """Assemble and evaluate the model for the tasks in ``args.queriedTask``.

    The library network and one expert per queried task are restored from
    their checkpoints, moved to ``args.device``, put in eval mode and combined
    through ``network.wresnet.wideresnet_MQ``.  The combined model is tested
    once and its accuracy printed.

    Returns:
        The combined ``model_MQ``.
    """

    def _restore_expert(task):
        # Build one expert sized for ``task`` and load its saved weights.
        task_idx = CIFAR100_Superclass[task]
        ckpt_path = './DB_Pool of Experts/Experts/expert_%s.pt' % task
        net = network.wresnet.wideresnet_cl(
            depth=16,
            num_classes=len(task_idx),
            EX_widen_factor=1,
            widen_factor=0.25,
            dropRate=0.0,
        )
        net.load_state_dict(torch.load(ckpt_path))
        net = net.to(args.device)
        net.eval()
        return net

    total_idx = total_combine(args.queriedTask)
    idx_dict = idx_search(args.queriedTask)
    # Only the test loader is needed here.
    _, test_loader = get_dataloader(args, test_subidx=total_idx)

    library = network.wresnet.wideresnet_ex(
        depth=16,
        num_classes=args.Oracle_classes,
        widen_factor=1,
        dropRate=0.0,
    )
    library_state = torch.load('./DB_Pool of Experts/Library/library_cifar100.pt')
    library.load_state_dict(library_state)
    library = library.to(args.device)
    library.eval()

    experts = [_restore_expert(task) for task in args.queriedTask]

    model_MQ = network.wresnet.wideresnet_MQ(library=library, experts=experts)
    best_acc = test(model_MQ, args.device, test_loader, idx_dict,
                    args.queriedTask)
    print("\nModel_MQ Acc=%.2f%%" % (best_acc * 100))
    return model_MQ
def create_hooks(args, model, optimizer, losses, logger, serializer):
    """Build the training hooks and their periodic wrappers.

    A serialization hook is always created.  A validation hook is added
    unless ``args.skip_validation`` is set.  Each hook is then wrapped with
    ``make_hook_periodic`` using its configured period.

    Returns:
        ``(periodic_hooks, hooks)`` - two dicts keyed by hook name, holding
        the wrapped and the raw hooks respectively.
    """
    device = torch.device(args.device)
    # The validation loader is built even when validation is skipped.
    val_loader = get_dataloader(get_valset_params(args))

    hooks = {}
    periods = {}

    hooks['serialization'] = SerializationHook(serializer, model, optimizer,
                                               logger)
    periods['serialization'] = args.checkpointing_interval

    if not args.skip_validation:
        # only raw events can be used for validation
        hooks['validation'] = ValidationHook(
            model,
            device,
            val_loader,
            logger,
            losses,
            weights=args.loss_weights,
            is_raw=True,
        )
        periods['validation'] = args.vp

    periodic_hooks = {}
    for hook_name, period in periods.items():
        periodic_hooks[hook_name] = make_hook_periodic(hooks[hook_name], period)

    return periodic_hooks, hooks
def main():
    """Run the model over the validation set and queue visualizations.

    A pool of ``cpu_count()`` processes is started with ``image_writer`` and
    the shared queue passed in ``Pool``'s initializer position.  For each
    validation minibatch whose output files do not all exist yet, the
    minibatch is processed without gradients, visualized, and the result is
    put on the queue.  One ``None`` per writer is queued at the end before the
    pool is closed and joined.
    """
    image_queue = Queue()
    num_writers = cpu_count()
    worker = Pool(num_writers, image_writer, (image_queue, ))

    args = parse_args(sys.argv[1:])
    args.mbs = 1
    output_dir = choose_output_path(args)

    model = init_model(args, args.device)
    model.eval()

    val_loader = get_dataloader(get_valset_params(args))
    sequence_length = args.prefix_length + args.suffix_length + 1
    evaluator = init_losses(args.shape, 1, model, args.device,
                            sequence_length=sequence_length)

    with torch.no_grad():
        progress = tqdm(enumerate(val_loader), total=len(val_loader))
        for batch_idx, batch in progress:
            output_file_path = output_dir / f'{batch_idx:04d}'

            # Skip minibatches whose outputs are already on disk.
            already_done = all(
                path.is_file() for path in files(output_file_path))
            if already_done:
                continue

            loss, parts, _tags, prediction = process_minibatch(
                model,
                batch,
                FakeTimer(),
                args.device,
                args.is_raw,
                evaluator,
                args.loss_weights,
                return_prediction=True,
            )
            visualization, stat = visualize(args, batch, loss, parts,
                                            args.loss_weights, prediction)
            image_queue.put((output_file_path, visualization, stat))

    # One None per writer process is queued before shutting the pool down.
    for _ in range(num_writers):
        image_queue.put(None)
    worker.close()
    worker.join()
def main():
    """Prune ``layer3`` channels of a trained MNIST classifier one at a time.

    The classifier and the two grids are restored from the checkpoint named
    after ``opt.dataset`` and ``opt.attack_mode``.  The outputs of
    ``netC.layer3`` are collected over the test loader, averaged over
    dimensions 0, 2 and 3, and the resulting per-index means are sorted in
    ascending order.  Then, for ``index`` = 0 .. N-1, a copy of the network
    with the ``index`` lowest-ranked entries masked out is evaluated (no
    fine-tuning) and one line ``index clean bd`` is written to
    ``mnist_<attack_mode>_results.txt``.

    Raises:
        Exception: if ``opt.dataset`` is not ``"mnist"``.
    """
    # Prepare arguments
    opt = get_arguments().parse_args()
    if opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
        netC = NetC_MNIST().to(opt.device)
    else:
        raise Exception("Invalid Dataset")

    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    state_dict = torch.load(opt.ckpt_path)
    print("load C")
    netC.load_state_dict(state_dict["netC"])
    netC.to(opt.device)
    netC.eval()
    netC.requires_grad_(False)
    print("load grid")
    identity_grid = state_dict["identity_grid"].to(opt.device)
    noise_grid = state_dict["noise_grid"].to(opt.device)
    print(state_dict["best_clean_acc"], state_dict["best_bd_acc"])

    # Prepare dataloader
    test_dl = get_dataloader(opt, train=False)

    # Print the names of the network's direct sub-modules.
    for name, module in netC._modules.items():
        print(name)

    # Forward hook for getting layer's output
    container = []

    def forward_hook(module, input, output):
        # Record every output produced by the hooked module.
        container.append(output)

    hook = netC.layer3.register_forward_hook(forward_hook)

    # Forwarding all the validation set
    print("Forwarding all the validation dataset:")
    for batch_idx, (inputs, _) in enumerate(test_dl):
        inputs = inputs.to(opt.device)
        netC(inputs)
        progress_bar(batch_idx, len(test_dl))

    # Processing to get the "more important mask"
    container = torch.cat(container, dim=0)
    # Average over dims 0, 2 and 3, leaving one value per index of dim 1
    # (presumably the channel dimension - TODO confirm).
    activation = torch.mean(container, dim=[0, 2, 3])
    # Ascending order: the lowest mean activations come first.
    seq_sort = torch.argsort(activation)
    # True = keep; entries are switched off cumulatively in the loop below.
    pruning_mask = torch.ones(seq_sort.shape[0], dtype=bool)
    hook.remove()

    # Pruning times - no-tuning after pruning a channel!!!
    acc_clean = []
    acc_bd = []
    with open("mnist_{}_results.txt".format(opt.attack_mode), "w") as outs:
        for index in range(pruning_mask.shape[0]):
            net_pruned = copy.deepcopy(netC)
            num_pruned = index
            if index:
                # The mask persists across iterations, so after this line
                # exactly ``index`` entries are False.
                channel = seq_sort[index - 1]
                pruning_mask[channel] = False
            print("Pruned {} filters".format(num_pruned))

            # Replace the two layers whose shapes depend on the kept count.
            net_pruned.layer3.conv1 = nn.Conv2d(pruning_mask.shape[0],
                                                pruning_mask.shape[0] - num_pruned,
                                                (3, 3),
                                                stride=2,
                                                padding=1,
                                                bias=False)
            net_pruned.linear6 = nn.Linear(
                (pruning_mask.shape[0] - num_pruned) * 16, 512)

            # Re-assigning weight to the pruned net
            for name, module in net_pruned._modules.items():
                if "layer3" in name:
                    module.conv1.weight.data = netC.layer3.conv1.weight.data[
                        pruning_mask]
                    module.ind = pruning_mask
                elif "linear6" == name:
                    # The original weights are viewed as (-1, 64, 16), the
                    # masked entries of the middle axis are dropped, and the
                    # result is flattened back to 512 rows.
                    module.weight.data = netC.linear6.weight.data.reshape(
                        -1, 64, 16)[:, pruning_mask].reshape(512, -1)  # [:, pruning_mask]
                    module.bias.data = netC.linear6.bias.data
                else:
                    continue
            net_pruned.to(opt.device)
            # NOTE: ``eval`` here is called with project arguments, so it must
            # be a project function shadowing the builtin, not builtins.eval.
            clean, bd = eval(net_pruned, identity_grid, noise_grid, test_dl, opt)
            outs.write("%d %0.4f %0.4f\n" % (index, clean, bd))
def main():
    """Entry point: build the model and its tools, then train or resume.

    Training resumes from the last step known to the serializer unless
    ``args.do_not_continue`` is set or no steps are known.  On a fresh start
    the serialization hook is invoked once for step 0.  The validation hook
    (unless skipped) runs once before training; both hooks run once more
    after training finishes.
    """
    # torch.autograd.set_detect_anomaly(True)
    args = parse_args(sys.argv[1:])

    device = torch.device(args.device)
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    timers = SynchronizedWallClockTimer() if args.timers else FakeTimer()

    model = init_model(args, device)
    serializer = Serializer(args.model, args.num_checkpoints,
                            args.permanent_interval)

    # Start from scratch when asked to, or when there is nothing to resume.
    args.do_not_continue = (args.do_not_continue
                            or len(serializer.list_known_steps()) == 0)
    if args.do_not_continue:
        last_step = 0
    else:
        last_step = serializer.list_known_steps()[-1]

    optimizer, scheduler = construct_train_tools(args, model,
                                                 passed_steps=last_step)
    losses = init_losses(
        args.shape,
        args.bs,
        model,
        device,
        sequence_length=args.prefix_length + args.suffix_length + 1,
        timers=timers,
    )

    # allow only manual flush
    logger = SummaryWriter(str(args.log_path),
                           max_queue=100000000,
                           flush_secs=100000000)

    periodic_hooks, hooks = create_hooks(args, model, optimizer, losses,
                                         logger, serializer)

    if args.do_not_continue:
        global_step = 0
        samples_passed = 0
        hooks['serialization'](global_step, samples_passed)
    else:
        global_step, state = serializer.load_checkpoint(model,
                                                        last_step,
                                                        optimizer=optimizer,
                                                        device=device)
        # Fall back to step * batch size when the count was not stored.
        samples_passed = state.pop('samples_passed', global_step * args.bs)

    loader = get_dataloader(get_trainset_params(args),
                            sample_idx=samples_passed,
                            process_only_once=False)

    if not args.skip_validation:
        hooks['validation'](global_step, samples_passed)

    with Profiler(args.profiling, args.model/'profiling'):
        with GPUMonitor(args.log_path):
            train(
                model,
                device,
                loader,
                optimizer,
                args.training_steps,
                scheduler=scheduler,
                evaluator=losses,
                logger=logger,
                weights=args.loss_weights,
                is_raw=args.is_raw,
                accumulation_steps=args.accum_step,
                timers=timers,
                hooks=periodic_hooks,
                init_step=global_step,
                init_samples_passed=samples_passed,
                max_events_per_batch=args.max_events_per_batch,
            )

    remaining_steps = args.training_steps - global_step
    samples = samples_passed + remaining_steps * args.bs
    hooks['serialization'](args.training_steps, samples)
    if not args.skip_validation:
        hooks['validation'](args.training_steps, samples)
def get_model(args):
    """Train, or load and evaluate, the UHC-from-scratch student.

    When ``args.model_pretrained is True`` both student parts are restored
    from ``DB_pretrained/UHC_Scratch``, evaluated once, and the function
    returns.  Otherwise one pre-trained (extractor, classifier) pair per entry
    of ``args.Superclasses`` is loaded from ``DB_pretrained/Scratch`` as the
    teacher list, and the student is trained with SGD for
    ``args.model_epochs`` epochs.  Both student parts are saved whenever the
    test accuracy improves.

    Args:
        args: Option namespace.  Reads ``Superclasses``, ``device``,
            ``model_pretrained``, ``lr``, ``weight_decay``, ``scheduler`` and
            ``model_epochs``.

    Returns:
        None.  The accuracy is printed and checkpoints are written to disk.
    """
    os.makedirs('DB_pretrained/UHC_Scratch', exist_ok=True)

    total_idx = total_combine(args.Superclasses)
    idx_dict = idx_search(args.Superclasses)
    train_loader, test_loader = get_dataloader(
        args, train_subidx=total_idx, test_subidx=total_idx)

    student_EX = network.wresnet.wideresnet_ex(
        depth=16, num_classes=100, widen_factor=1, dropRate=0.0)
    student_CL = network.wresnet.wideresnet_cl(
        depth=16, num_classes=len(total_idx), EX_widen_factor=1,
        widen_factor=(0.25 * len(args.Superclasses)), dropRate=0.0)

    student_EX_path = './DB_pretrained/UHC_Scratch/UHC_Scratch_EX_%s.pt' % args.Superclasses
    student_CL_path = './DB_pretrained/UHC_Scratch/UHC_Scratch_CL_%s.pt' % args.Superclasses

    if args.model_pretrained is True:
        # Evaluation-only path.
        student_EX.load_state_dict(torch.load(student_EX_path))
        student_CL.load_state_dict(torch.load(student_CL_path))
        student_EX, student_CL = student_EX.to(args.device), student_CL.to(
            args.device)
        student_EX, student_CL = student_EX.eval(), student_CL.eval()
        student = [student_EX, student_CL]
        best_acc = test(student, args.device, test_loader, 0, True)
        print("\nModel for %s Acc=%.2f%%" % (args.Superclasses, best_acc * 100))
        return

    student_EX, student_CL = student_EX.to(args.device), student_CL.to(
        args.device)

    # Per-superclass teacher extractors.  The file name is formatted with the
    # one-element list ``[s]``, so it contains the list's repr.
    Scratch_EX = []
    for s in args.Superclasses:
        Scratch_EX_path = './DB_pretrained/Scratch/Scratch_EX_%s.pt' % [s]
        e = network.wresnet.wideresnet_ex(
            depth=16, num_classes=100, widen_factor=1, dropRate=0.0)
        e.load_state_dict(torch.load(Scratch_EX_path))
        e = e.to(args.device)
        e.eval()
        Scratch_EX.append(e)

    # Per-superclass teacher classifiers (five outputs each).
    Scratch_CL = []
    for s in args.Superclasses:
        Scratch_CL_path = './DB_pretrained/Scratch/Scratch_CL_%s.pt' % [s]
        e = network.wresnet.wideresnet_cl(
            depth=16, num_classes=5, EX_widen_factor=1, widen_factor=0.25,
            dropRate=0.0)
        e.load_state_dict(torch.load(Scratch_CL_path))
        e = e.to(args.device)
        e.eval()
        Scratch_CL.append(e)

    # One (extractor, classifier) tuple per superclass.
    teacher = list(zip(Scratch_EX, Scratch_CL))

    optimizer_S = optim.SGD(
        list(student_EX.parameters()) + list(student_CL.parameters()),
        lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    if args.scheduler:
        scheduler_S = optim.lr_scheduler.MultiStepLR(optimizer_S, [80, 160], 0.1)

    best_acc = 0
    student = [student_EX, student_CL]
    for epoch in range(1, args.model_epochs + 1):
        train(args, teacher=teacher, student=student, device=args.device,
              train_loader=train_loader, optimizer=optimizer_S, epoch=epoch,
              idx_dict=idx_dict)
        if args.scheduler:
            # Step the schedule after the epoch's optimizer updates
            # (presumably done inside train()).  Since PyTorch 1.1, stepping
            # first skips the first value of the learning-rate schedule.
            scheduler_S.step()
        acc = test(student, args.device, test_loader, epoch)
        if acc > best_acc:
            best_acc = acc
            torch.save(student_EX.state_dict(), student_EX_path)
            torch.save(student_CL.state_dict(), student_CL_path)
    print("\nModel for %s Acc=%.2f%%" % (args.Superclasses, best_acc * 100))
def get_library(args, Oracle):
    """Return the library extractor, either restored from disk or trained.

    A library extractor and a companion classifier (both depth 16) are built.
    When ``args.Library_pretrained is True`` both are restored, evaluated
    once, and the extractor is returned.  Otherwise both are trained with SGD
    for ``args.Oracle_epochs`` epochs with ``Oracle`` passed to ``train``,
    saving both parts whenever the test accuracy improves.  The best
    extractor weights are then reloaded before returning.

    Args:
        args: Option namespace.  Reads ``Oracle_classes``,
            ``Library_pretrained``, ``device``, ``lr``, ``weight_decay``,
            ``scheduler`` and ``Oracle_epochs``.
        Oracle: Model passed to ``train`` as ``Oracle``; set to eval mode.

    Returns:
        The library extractor in eval mode.
    """
    library_path = './DB_Pool of Experts/Library/library_cifar100.pt'
    classifier_path = './DB_pretrained/student for library/classifier_cifar100.pt'

    os.makedirs('DB_Pool of Experts/Library', exist_ok=True)
    # The classifier checkpoint lives in a different directory; create it as
    # well, otherwise the first torch.save below fails on a clean tree.
    os.makedirs('DB_pretrained/student for library', exist_ok=True)

    train_loader, test_loader = get_dataloader(args)
    student_library = network.wresnet.wideresnet_ex(
        depth=16, num_classes=args.Oracle_classes, widen_factor=1,
        dropRate=0.0)
    student_classifier = network.wresnet.wideresnet_cl(
        depth=16, num_classes=args.Oracle_classes, EX_widen_factor=1,
        widen_factor=1, dropRate=0.0)

    if args.Library_pretrained is True:
        # Evaluation-only path.
        student_library.load_state_dict(torch.load(library_path))
        student_classifier.load_state_dict(torch.load(classifier_path))
        student_library = student_library.to(args.device)
        student_classifier = student_classifier.to(args.device)
        student_library.eval()
        student_classifier.eval()
        student = [student_library, student_classifier]
        best_acc = test(student, args.device, test_loader, 0, True)
        print("\nStudent for Library Acc=%.2f%%" % (best_acc * 100))
        return student_library

    student_library = student_library.to(args.device)
    student_classifier = student_classifier.to(args.device)
    Oracle.eval()
    optimizer_S = optim.SGD(
        list(student_library.parameters())
        + list(student_classifier.parameters()),
        lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    if args.scheduler:
        scheduler_S = optim.lr_scheduler.MultiStepLR(optimizer_S, [80, 160], 0.1)

    best_acc = 0
    student = [student_library, student_classifier]
    # NOTE(review): the epoch count comes from args.Oracle_epochs; confirm
    # that reusing the Oracle setting for library training is intended.
    for epoch in range(1, args.Oracle_epochs + 1):
        train(args, Oracle=Oracle, student=student, device=args.device,
              train_loader=train_loader, optimizer=optimizer_S, epoch=epoch)
        if args.scheduler:
            # Step the schedule after the epoch's optimizer updates
            # (presumably done inside train()).  Since PyTorch 1.1, stepping
            # first skips the first value of the learning-rate schedule.
            scheduler_S.step()
        acc = test(student, args.device, test_loader, epoch)
        if acc > best_acc:
            best_acc = acc
            torch.save(student_library.state_dict(), library_path)
            torch.save(student_classifier.state_dict(), classifier_path)
    print("\nStudent for Library Acc=%.2f%%" % (best_acc * 100))

    # Return the best checkpoint rather than the last-epoch weights.
    student_library.load_state_dict(torch.load(library_path))
    student_library.eval()
    return student_library
def strip(opt, mode="clean"):
    """Collect STRIP entropies for clean and, optionally, backdoored inputs.

    The classifier for ``opt.dataset`` is restored from the checkpoint named
    after ``opt.dataset`` and ``opt.attack_mode``.  ``mode`` selects what is
    measured:

    * ``"attack"``: the first ``opt.n_test`` test inputs are backdoored with
      the checkpoint's grids and scored, then the first ``opt.n_test`` clean
      test samples are scored.
    * anything else: only the clean samples are scored.

    The grids are read from the checkpoint unless ``mode == "clean"``.

    Previously ``mode`` was overwritten by ``opt.attack_mode`` before it was
    ever read, so the argument had no effect; it is now honoured.

    Args:
        opt: Option namespace.  Reads ``dataset``, ``device``,
            ``num_classes``, ``checkpoints``, ``attack_mode`` and ``n_test``;
            overwrites ``ckpt_folder``, ``ckpt_path``, ``log_dir`` and ``bs``.
        mode: ``"clean"`` (default) or ``"attack"``.

    Returns:
        ``(list_entropy_trojan, list_entropy_benign)``; the first list is
        empty unless ``mode == "attack"``.

    Raises:
        Exception: if ``opt.dataset`` is not a supported dataset.
    """
    # Prepare pretrained classifier
    if opt.dataset == "mnist":
        netC = NetC_MNIST().to(opt.device)
    elif opt.dataset == "cifar10" or opt.dataset == "gtsrb":
        netC = PreActResNet18(num_classes=opt.num_classes).to(opt.device)
    elif opt.dataset == "celeba":
        netC = ResNet18().to(opt.device)
    else:
        raise Exception("Invalid dataset")

    # Load pretrained model.  The checkpoint file is named after the attack
    # mode; keep that separate from the ``mode`` argument.
    attack_mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, attack_mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    state_dict = torch.load(opt.ckpt_path)
    netC.load_state_dict(state_dict["netC"])
    if mode != "clean":
        identity_grid = state_dict["identity_grid"]
        noise_grid = state_dict["noise_grid"]
    netC.requires_grad_(False)
    netC.eval()
    netC.to(opt.device)

    # Prepare test set.  The batch size is set to n_test so that a single
    # batch provides every sample used below.
    testset = get_dataset(opt, train=False)
    opt.bs = opt.n_test
    test_dataloader = get_dataloader(opt, train=False)
    denormalizer = Denormalizer(opt)

    # STRIP detector
    strip_detector = STRIP(opt)

    # Entropy list
    list_entropy_trojan = []
    list_entropy_benign = []

    if mode == "attack":
        # Testing with perturbed data
        print("Testing with bd data !!!!")
        inputs, _ = next(iter(test_dataloader))
        inputs = inputs.to(opt.device)
        bd_inputs = create_backdoor(inputs, identity_grid, noise_grid, opt)

        # Scale to 0..255 and convert to a uint8 array with axis 1 moved last.
        bd_inputs = denormalizer(bd_inputs) * 255.0
        bd_inputs = bd_inputs.detach().cpu().numpy()
        bd_inputs = np.clip(bd_inputs, 0, 255).astype(np.uint8).transpose(
            (0, 2, 3, 1))

        for index in range(opt.n_test):
            background = bd_inputs[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_trojan.append(entropy)
            progress_bar(index, opt.n_test)

        # Testing with clean data
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
    else:
        # Testing with clean data
        print("Testing with clean data !!!!")
        for index in range(opt.n_test):
            background, _ = testset[index]
            entropy = strip_detector(background, testset, netC)
            list_entropy_benign.append(entropy)
            progress_bar(index, opt.n_test)

    return list_entropy_trojan, list_entropy_benign
def train_dataloader(self):
    """Return the first of the three loaders built from ``self.hparams``."""
    loaders = get_dataloader(self.hparams)
    # Exactly three items are expected; only the first one is used.
    first_loader, _, _ = loaders
    return first_loader
def main():
    """Train the classifier, from scratch or from an existing checkpoint.

    The dataset-specific settings (``num_classes`` and input size) are
    written onto ``opt`` first.  With ``opt.continue_training`` the model,
    optimizer, scheduler, best accuracies, current epoch and both grids are
    restored from ``opt.ckpt_path``; the process exits if that file is
    missing.  Otherwise the grids are generated, the checkpoint folder is
    wiped and recreated, and the options are dumped to ``opt.json``.  Then
    ``train`` and ``eval`` are run for each remaining epoch up to
    ``opt.n_iters``.

    Raises:
        Exception: if ``opt.dataset`` is not a supported dataset.
    """
    opt = config.get_arguments().parse_args()

    if opt.dataset in ["mnist", "cifar10"]:
        opt.num_classes = 10
    elif opt.dataset == "gtsrb":
        opt.num_classes = 43
    elif opt.dataset == "celeba":
        opt.num_classes = 8
    else:
        raise Exception("Invalid Dataset")

    if opt.dataset in ["cifar10", "gtsrb"]:
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    elif opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
    elif opt.dataset == "celeba":
        opt.input_height = 64
        opt.input_width = 64
        opt.input_channel = 3
    else:
        raise Exception("Invalid Dataset")

    # Dataset
    train_dl = get_dataloader(opt, True)
    test_dl = get_dataloader(opt, False)

    # prepare model
    netC, optimizerC, schedulerC = get_model(opt)

    # Load pretrained model
    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    opt.ckpt_path = os.path.join(
        opt.ckpt_folder, "{}_{}_morph.pth.tar".format(opt.dataset, mode))
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")
    os.makedirs(opt.log_dir, exist_ok=True)

    if opt.continue_training:
        if os.path.exists(opt.ckpt_path):
            print("Continue training!!")
            state_dict = torch.load(opt.ckpt_path)
            netC.load_state_dict(state_dict["netC"])
            optimizerC.load_state_dict(state_dict["optimizerC"])
            schedulerC.load_state_dict(state_dict["schedulerC"])
            best_clean_acc = state_dict["best_clean_acc"]
            best_bd_acc = state_dict["best_bd_acc"]
            best_cross_acc = state_dict["best_cross_acc"]
            epoch_current = state_dict["epoch_current"]
            identity_grid = state_dict["identity_grid"]
            noise_grid = state_dict["noise_grid"]
            tf_writer = SummaryWriter(log_dir=opt.log_dir)
        else:
            print("Pretrained model doesnt exist")
            exit()
    else:
        print("Train from scratch!!!")
        best_clean_acc = 0.0
        best_bd_acc = 0.0
        best_cross_acc = 0.0
        epoch_current = 0

        # Prepare grid: a random k x k field in [-1, 1), divided by its mean
        # absolute value and upsampled to the input height.
        ins = torch.rand(1, 2, opt.k, opt.k) * 2 - 1
        ins = ins / torch.mean(torch.abs(ins))
        # F.interpolate replaces the deprecated F.upsample (same arguments).
        noise_grid = (F.interpolate(ins,
                                    size=opt.input_height,
                                    mode="bicubic",
                                    align_corners=True).permute(0, 2, 3, 1).to(
                                        opt.device))
        array1d = torch.linspace(-1, 1, steps=opt.input_height)
        x, y = torch.meshgrid(array1d, array1d)
        identity_grid = torch.stack((y, x), 2)[None, ...].to(opt.device)

        # Start with an empty checkpoint folder.  This removes everything
        # under opt.ckpt_folder, not just this run's files.
        shutil.rmtree(opt.ckpt_folder, ignore_errors=True)
        os.makedirs(opt.log_dir, exist_ok=True)
        with open(os.path.join(opt.ckpt_folder, "opt.json"), "w+") as f:
            json.dump(opt.__dict__, f, indent=2)
        tf_writer = SummaryWriter(log_dir=opt.log_dir)

    for epoch in range(epoch_current, opt.n_iters):
        print("Epoch {}:".format(epoch + 1))
        train(netC, optimizerC, schedulerC, train_dl, noise_grid,
              identity_grid, tf_writer, epoch, opt)
        # NOTE: ``eval`` here is called with project arguments, so it must be
        # a project function shadowing the builtin, not builtins.eval.
        best_clean_acc, best_bd_acc, best_cross_acc = eval(
            netC,
            optimizerC,
            schedulerC,
            test_dl,
            noise_grid,
            identity_grid,
            best_clean_acc,
            best_bd_acc,
            best_cross_acc,
            tf_writer,
            epoch,
            opt,
        )
def main():
    """Evaluate the checkpointed classifier on the test split.

    Dataset-specific settings are written onto ``opt``, the model is built,
    and the classifier weights plus both grids are restored from
    ``opt.ckpt_path``.  If the checkpoint does not exist a message is printed
    and the process exits.  Finally ``eval`` is run once.

    Raises:
        Exception: if ``opt.dataset`` is not a supported dataset.
    """
    opt = config.get_arguments().parse_args()

    # dataset -> number of classes
    class_counts = {"mnist": 10, "cifar10": 10, "gtsrb": 43, "celeba": 8}
    # dataset -> (input_height, input_width, input_channel)
    input_shapes = {
        "cifar10": (32, 32, 3),
        "gtsrb": (32, 32, 3),
        "mnist": (28, 28, 1),
        "celeba": (64, 64, 3),
    }

    if opt.dataset not in class_counts:
        raise Exception("Invalid Dataset")
    opt.num_classes = class_counts[opt.dataset]

    if opt.dataset not in input_shapes:
        raise Exception("Invalid Dataset")
    height, width, channels = input_shapes[opt.dataset]
    opt.input_height = height
    opt.input_width = width
    opt.input_channel = channels

    # Dataset
    test_dl = get_dataloader(opt, False)

    # prepare model
    netC, optimizerC, schedulerC = get_model(opt)

    # Load pretrained model
    mode = opt.attack_mode
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    ckpt_name = "{}_{}_morph.pth.tar".format(opt.dataset, mode)
    opt.ckpt_path = os.path.join(opt.ckpt_folder, ckpt_name)
    opt.log_dir = os.path.join(opt.ckpt_folder, "log_dir")

    if not os.path.exists(opt.ckpt_path):
        print("Pretrained model doesnt exist")
        exit()

    state_dict = torch.load(opt.ckpt_path)
    netC.load_state_dict(state_dict["netC"])
    identity_grid = state_dict["identity_grid"]
    noise_grid = state_dict["noise_grid"]

    # NOTE: ``eval`` here is called with project arguments, so it must be a
    # project function shadowing the builtin, not builtins.eval.
    eval(netC, optimizerC, schedulerC, test_dl, noise_grid, identity_grid, opt)