def main():
    """Adversarially train ``args.arch`` on ``args.dataset`` with TRADES and checkpoint periodically.

    Relies on module-level ``args``, ``train_loader``, ``test_loader`` and the
    helpers ``adjust_learning_rate`` / ``train`` / ``eval_train`` / ``eval_test``.
    """
    model = StandardModel(args.dataset, args.arch, no_grad=False, load_pretrained=False)
    model.cuda()
    model.train()
    device = torch.device("cuda")
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    model_path = '{}/train_pytorch_model/adversarial_train/TRADES/{}@{}@epoch_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("After trained, the model will save to {}".format(model_path))
    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)
        # adversarial training
        train(args, model, device, train_loader, optimizer, epoch)
        # evaluation on natural examples
        print('================================================================')
        eval_train(model, device, train_loader)
        eval_test(model, device, test_loader)
        print('================================================================')
        # save checkpoint
        if epoch % args.save_freq == 0:
            state = {
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'optimizer': optimizer.state_dict()
            }
            # BUG FIX: original did torch.save(state, os.path.join(model_dir, model_path))
            # with an undefined `model_dir`; model_path is already the full destination
            # (its directory was created above and announced to the user).
            torch.save(state, model_path)
def __init__(self, dataset: str, batch_size: int, meta_arch: str, meta_train_type: str,
             distill_loss: str, data_loss: str, norm: str, targeted: bool,
             use_softmax: bool, mode: str = "meta"):
    """Build the simulator network(s) and per-batch weight table for one of several modes.

    Modes select where the backbone weights come from:
      * "meta"                     -- a meta-trained simulator checkpoint found via glob;
                                      arch and inner_lr are parsed out of the filename.
      * "vanilla"                  -- a vanilla (non-meta) simulator checkpoint.
      * "deep_benign_images"       -- a resnet34 classifier trained on real images.
      * "random_init"              -- a freshly initialized resnet34 (no pretrained file).
      * "ensemble_avg"             -- several pretrained StandardModels; meta_network and
                                      pretrained_weights become *lists* in this mode.
      * "benign_images" /
        "reptile_on_benign_images" -- meta simulators trained on benign images only.

    After the mode branch, common attributes (losses, softmax, batch_weights) are set.
    ``self.batch_weights`` maps each per-batch slot index to a (shared) state_dict so
    each slot can later be fine-tuned independently.
    """
    if mode == "meta":
        target_str = "targeted_attack_random" if targeted else "untargeted_attack"
        # Example checkpoint filename this glob is meant to match:
        # 2Q_DISTILLATION@CIFAR-100@TRAIN_I_TEST_II@model_resnet34@loss_pair_mse@dataloss_cw_l2_untargeted_attack@epoch_4@meta_batch_size_30@num_support_50@num_updates_12@lr_0.001@inner_lr_0.01.pth.tar
        self.meta_model_path = "{root}/train_pytorch_model/meta_simulator/{meta_train_type}@{dataset}@{split}@model_{meta_arch}@loss_{loss}@dataloss_{data_loss}_{norm}_{target_str}*inner_lr_0.01.pth.tar".format(
            root=PY_ROOT, meta_train_type=meta_train_type.upper(), dataset=dataset,
            split=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, meta_arch=meta_arch, loss=distill_loss,
            data_loss=data_loss, norm=norm, target_str=target_str)
        log.info("start using {}".format(self.meta_model_path))
        self.meta_model_path = glob.glob(self.meta_model_path)
        # Parse the backbone arch (group 1) and inner learning rate (group 2)
        # back out of the matched checkpoint filename.
        pattern = re.compile(".*model_(.*?)@.*inner_lr_(.*?)\.pth.*")
        assert len(self.meta_model_path) > 0
        self.meta_model_path = self.meta_model_path[0]
        log.info("load meta model {}".format(self.meta_model_path))
        ma = pattern.match(os.path.basename(self.meta_model_path))
        arch = ma.group(1)
        self.inner_lr = float(ma.group(2))
        meta_backbone = self.construct_model(arch, dataset)
        self.meta_network = MetaNetwork(meta_backbone)
        # map_location keeps tensors on CPU regardless of where they were saved.
        self.pretrained_weights = torch.load(self.meta_model_path,
                                             map_location=lambda storage, location: storage)
        self.meta_network.load_state_dict(self.pretrained_weights["state_dict"])
        log.info("Load model in epoch {}.".format(self.pretrained_weights["epoch"]))
        # From here on, pretrained_weights holds only the state_dict.
        self.pretrained_weights = self.pretrained_weights["state_dict"]
    elif mode == "vanilla":
        target_str = "targeted" if targeted else "untargeted"
        arch = meta_arch
        self.meta_model_path = "{root}/train_pytorch_model/vanilla_simulator/{dataset}@{norm}_norm_{target_str}@{meta_arch}*.tar".format(
            root=PY_ROOT, dataset=dataset, meta_arch=meta_arch, norm=norm, target_str=target_str)
        log.info("start using {}".format(self.meta_model_path))
        self.meta_model_path = glob.glob(self.meta_model_path)
        assert len(self.meta_model_path) > 0
        self.meta_model_path = self.meta_model_path[0]
        log.info("load meta model {}".format(self.meta_model_path))
        self.inner_lr = 0.01
        # Vanilla mode uses the raw backbone, not a MetaNetwork wrapper.
        self.meta_network = self.construct_model(meta_arch, dataset)
        self.pretrained_weights = torch.load(self.meta_model_path,
                                             map_location=lambda storage, location: storage)
        log.info("Load model in epoch {}.".format(self.pretrained_weights["epoch"]))
        # Weights are loaded into the network in the common block below.
        self.pretrained_weights = self.pretrained_weights["state_dict"]
    elif mode == "deep_benign_images":
        arch = "resnet34"
        self.inner_lr = 0.01
        self.meta_network = self.construct_model(arch, dataset)
        # Fixed (non-glob) path to a classifier trained on real benign images.
        self.meta_model_path = "{root}/train_pytorch_model/real_image_model/{dataset}@{arch}@epoch_200@lr_0.1@batch_200.pth.tar".format(
            root=PY_ROOT, dataset=dataset, arch=arch)
        assert os.path.exists(self.meta_model_path), "{} does not exists!".format(self.meta_model_path)
        self.pretrained_weights = torch.load(self.meta_model_path,
                                             map_location=lambda storage, location: storage)["state_dict"]
    elif mode == "random_init":
        arch = "resnet34"
        self.inner_lr = 0.01
        # No checkpoint at all: start from the freshly initialized weights.
        self.meta_network = self.construct_model(arch, dataset)
        self.pretrained_weights = self.meta_network.state_dict()
    elif mode == 'ensemble_avg':
        self.inner_lr = 0.01
        self.archs = ["densenet-bc-100-12", "resnet-110", "vgg19_bn"]
        self.meta_network = []  # both meta_network and pretrained_weights become lists in this mode
        self.pretrained_weights = []
        for arch in self.archs:
            model = StandardModel(dataset, arch, no_grad=False, load_pretrained=True)
            model.eval()
            model.cuda()
            self.meta_network.append(model)
            self.pretrained_weights.append(model.state_dict())
    elif mode == "benign_images":
        self.inner_lr = 0.01
        self.meta_model_path = "{root}/train_pytorch_model/meta_simulator_on_benign_images/{dataset}@{split}*@inner_lr_0.01.pth.tar".format(
            root=PY_ROOT, meta_train_type=meta_train_type.upper(), dataset=dataset,
            split=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II)
        self.meta_model_path = glob.glob(self.meta_model_path)
        pattern = re.compile(".*model_(.*?)@.*")
        assert len(self.meta_model_path) > 0
        self.meta_model_path = self.meta_model_path[0]
        ma = pattern.match(os.path.basename(self.meta_model_path))
        log.info("Loading meta model from {}".format(self.meta_model_path))
        arch = ma.group(1)
        self.pretrained_weights = torch.load(self.meta_model_path,
                                             map_location=lambda storage, location: storage)["state_dict"]
        meta_backbone = self.construct_model(arch, dataset)
        self.meta_network = MetaNetwork(meta_backbone)
        # NOTE(review): these weights are loaded again in the common
        # `if mode != "ensemble_avg"` block below — redundant but harmless.
        self.meta_network.load_state_dict(self.pretrained_weights)
        self.meta_network.eval()
        self.meta_network.cuda()
    elif mode == "reptile_on_benign_images":
        self.inner_lr = 0.01
        self.meta_model_path = "{root}/train_pytorch_model/meta_simulator_reptile_on_benign_images/{dataset}@{split}*@inner_lr_0.01.pth.tar".format(
            root=PY_ROOT, meta_train_type=meta_train_type.upper(), dataset=dataset,
            split=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II)
        self.meta_model_path = glob.glob(self.meta_model_path)
        pattern = re.compile(".*model_(.*?)@.*")
        assert len(self.meta_model_path) > 0
        self.meta_model_path = self.meta_model_path[0]
        log.info("Loading meta model from {}".format(self.meta_model_path))
        ma = pattern.match(os.path.basename(self.meta_model_path))
        arch = ma.group(1)
        self.pretrained_weights = torch.load(self.meta_model_path,
                                             map_location=lambda storage, location: storage)["state_dict"]
        meta_backbone = self.construct_model(arch, dataset)
        self.meta_network = MetaNetwork(meta_backbone)
        # NOTE(review): loaded again in the common block below — redundant but harmless.
        self.meta_network.load_state_dict(self.pretrained_weights)
        self.meta_network.eval()
        self.meta_network.cuda()
    # ---- common setup shared by all modes ----
    self.arch = arch
    self.dataset = dataset
    # Pairwise-distance distillation is only needed for the pair_mse loss.
    self.need_pair_distance = (distill_loss.lower() == "pair_mse")
    self.softmax = nn.Softmax(dim=1)
    self.mse_loss = nn.MSELoss(reduction="mean")
    self.pair_wise_distance = nn.PairwiseDistance(p=2)
    self.use_softmax = use_softmax
    if mode != "ensemble_avg":
        # Single-network modes: (re)load weights and move to GPU in eval mode.
        self.meta_network.load_state_dict(self.pretrained_weights)
        self.meta_network.eval()
        self.meta_network.cuda()
    self.batch_size = batch_size
    if mode == 'ensemble_avg':
        # One weight table per ensemble member: batch_weights[member_idx][slot] = state_dict.
        self.batch_weights = defaultdict(dict)
        for idx in range(len(self.pretrained_weights)):
            for i in range(batch_size):
                self.batch_weights[idx][i] = self.pretrained_weights[idx]
    else:
        # One shared state_dict per batch slot (same object; copied on fine-tune elsewhere,
        # presumably — verify against the caller).
        self.batch_weights = dict()
        for i in range(batch_size):
            self.batch_weights[i] = self.pretrained_weights
# --- Fragment: tail of a training function whose `def` / epoch loop header is not in
# --- this chunk. Comments only; code left untouched.
# One optimization step for the current (adversarial + clean) loss.
model.zero_grad()
loss.backward()
optimizer.step()
# NOTE(review): dividing by the last enumerate index `i` (not i+1 / batch count)
# slightly overstates the mean loss and raises ZeroDivisionError for a
# single-batch loader — confirm intended.
logger.info('[%d] train loss: adv: %.3f, clean: %.3f' % (epoch + 1, running_loss_1 / i, running_loss_2 / i))
if epoch % EVALUATE_EPOCH == 0:
    # Periodic evaluation pass over the validation loader.
    running_loss, correct, total = 0.0, 0.0, 0.0
    model.eval()
    for i, data_batch in enumerate(val_loader):
        # get the inputs; data is a list of [inputs, labels]
        img_batch, label_batch = data_batch
        img_batch, label_batch = img_batch.cuda(), label_batch.cuda()
        output_batch = model(img_batch)
        loss = criterion(output_batch, label_batch)
        running_loss += loss.item()
        _, predicted = torch.max(output_batch.data, 1)
        # labels appear to be one-hot encoded (argmax taken here) — TODO confirm
        _, label_ind = torch.max(label_batch.data, 1)
        correct += (predicted == label_ind).sum().item()
        total += label_batch.size(0)
    # NOTE(review): same `running_loss / i` off-by-one / div-by-zero concern as above.
    logger.info('[%d] test loss: %.3f, accuracy: %.3f' % (epoch + 1, running_loss / i, correct / total))
if epoch % args.save_epoch == 0 or epoch == EPOCH_TOTAL - 1:
    # NOTE(review): "eopch" is a typo in the checkpoint filename, but it is a runtime
    # string — renaming it would break any loader that globs for the old name.
    torch.save(model.state_dict(), os.path.join(MODELS_FOLDER, "eopch{}.ckpt".format(epoch)))
logger.info('Finished Training')
def main():
    """Fast adversarial training (fgsm / free / pgd / none) with apex AMP.

    Builds the model and data loaders from CLI args, trains with the selected
    attack, optionally early-stops on catastrophic overfitting (--overfit_check),
    and saves the best (or final) state_dict to a deterministic path.
    Relies on module-level ``std``, ``lower_limit``, ``upper_limit``, ``clamp``,
    ``attack_pgd``, ``initialize_weights``, ``logger`` and ``PY_ROOT``.
    """
    args = get_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    logger.info(args)
    model_path = '{}/train_pytorch_model/adversarial_train/fast_adv_train/{}@{}@epoch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs)
    out_dir = os.path.dirname(model_path)
    os.makedirs(out_dir, exist_ok=True)
    # Seed all RNGs for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    start_start_time = time.time()
    train_loader = DataLoaderMaker.get_img_label_data_loader(args.dataset, args.batch_size, True)
    test_loader = DataLoaderMaker.get_img_label_data_loader(args.dataset, args.batch_size, False)
    # epsilon/alpha are given in 0-255 pixel units; convert to normalized-image units.
    epsilon = (args.epsilon / 255.) / std
    pgd_alpha = (args.pgd_alpha / 255.) / std
    model = StandardModel(args.dataset, args.arch, no_grad=False)
    model.apply(initialize_weights)
    model.cuda()
    model.train()
    opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=5e-4)
    model, opt = amp.initialize(model, opt, opt_level="O2", loss_scale=1.0, master_weights=False)
    criterion = nn.CrossEntropyLoss()
    # 'free' and fgsm-with-previous-init reuse one persistent perturbation tensor.
    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    if args.attack == 'free':
        # Free training replays each minibatch attack_iters times, so shrink the epoch count.
        assert args.epochs % args.attack_iters == 0
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs
    if args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs],
                                          [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    prev_robust_acc = 0.
    # BUG FIX: best_state_dict was only ever assigned inside the overfit_check branch,
    # so the final torch.save raised NameError whenever --overfit_check was off.
    best_state_dict = None
    logger.info('Epoch \t Time \t LR \t \t Train Loss \t Train Acc')
    for epoch in range(epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            X = X.cuda().float()
            y = y.cuda().long()
            if i == 0:
                first_batch = X, y  # kept for the per-epoch robustness probe below
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)
            if args.attack == 'pgd':
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, opt)
            elif args.attack == 'fgsm':
                if args.fgsm_init == 'zero':
                    # (removed redundant second `delta.requires_grad = True`)
                    delta = torch.zeros_like(X, requires_grad=True)
                elif args.fgsm_init == 'random':
                    # Uniform init within the per-channel epsilon ball.
                    delta = torch.zeros_like(X).cuda()
                    delta[:, 0, :, :].uniform_(-epsilon[0][0][0].item(), epsilon[0][0][0].item())
                    delta[:, 1, :, :].uniform_(-epsilon[1][0][0].item(), epsilon[1][0][0].item())
                    delta[:, 2, :, :].uniform_(-epsilon[2][0][0].item(), epsilon[2][0][0].item())
                    delta.requires_grad = True
                elif args.fgsm_init == 'previous':
                    delta.requires_grad = True
                # Single FGSM step on delta (model weights updated after the branch).
                output = model(X + delta[:X.size(0)])
                loss = F.cross_entropy(output, y)
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
                grad = delta.grad.detach()
                delta.data = clamp(delta + args.fgsm_alpha * epsilon * torch.sign(grad), -epsilon, epsilon)
                delta = delta.detach()
            elif args.attack == 'free':
                # "Free" training: update both delta and the model attack_iters times per batch.
                delta.requires_grad = True
                for j in range(args.attack_iters):
                    epoch_iters = epoch * args.attack_iters + (i * args.attack_iters + j + 1) / len(train_loader)
                    lr = lr_schedule(epoch_iters)
                    opt.param_groups[0].update(lr=lr)
                    output = model(clamp(X + delta[:X.size(0)], lower_limit, upper_limit))
                    loss = F.cross_entropy(output, y)
                    opt.zero_grad()
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        scaled_loss.backward()
                    grad = delta.grad.detach()
                    delta.data = clamp(delta + epsilon * torch.sign(grad), -epsilon, epsilon)
                    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                    opt.step()
                    delta.grad.zero_()
            elif args.attack == 'none':
                delta = torch.zeros_like(X)
            output = model(clamp(X + delta[:X.size(0)], lower_limit, upper_limit))
            loss = criterion(output, y)
            if args.attack != 'free':
                # 'free' already stepped the optimizer inside its inner loop.
                opt.zero_grad()
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                opt.step()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
        if args.overfit_check:
            # Probe PGD robustness on the first minibatch; stop early if robustness
            # collapses (catastrophic overfitting of fast adversarial training).
            # BUG FIX: first_batch is an (X, y) tuple, not a dict — the original
            # indexed it with ['input'] / ['target'] and raised TypeError.
            X, y = first_batch
            pgd_delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, opt)
            with torch.no_grad():
                output = model(clamp(X + pgd_delta[:X.size(0)], lower_limit, upper_limit))
            robust_acc = (output.max(1)[1] == y).sum().item() / y.size(0)
            if robust_acc - prev_robust_acc < -0.5:
                break
            prev_robust_acc = robust_acc
            best_state_dict = copy.deepcopy(model.state_dict())
        train_time = time.time()
        logger.info('%d \t %.1f \t %.4f \t %.4f \t %.4f', epoch, train_time - start_time, lr,
                    train_loss / train_n, train_acc / train_n)
    # Fall back to the final weights when the robustness probe never ran.
    torch.save(best_state_dict if best_state_dict is not None else model.state_dict(), model_path)
    logger.info('Total time: %.4f', train_time - start_start_time)