Example #1
def main():
    model = StandardModel(args.dataset,
                          args.arch,
                          no_grad=False,
                          load_pretrained=False)
    model.cuda()
    model.train()
    device = torch.device("cuda")
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    model_path = '{}/train_pytorch_model/adversarial_train/TRADES/{}@{}@epoch_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("After trained, the model will save to {}".format(model_path))
    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model, device, train_loader, optimizer, epoch)

        # evaluation on natural examples
        print(
            '================================================================')
        eval_train(model, device, train_loader)
        eval_test(model, device, test_loader)
        print(
            '================================================================')

        # save checkpoint
        if epoch % args.save_freq == 0:
            state = {
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, model_path)
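
The adjust_learning_rate helper called above is not shown in this example. Below is a minimal sketch, modeled on the step-decay schedule used by reference TRADES implementations; the milestone epochs (75 and 90) are assumptions, not taken from this code.

def adjust_learning_rate(optimizer, epoch):
    # Step decay: cut the learning rate by 10x at the (assumed) milestone epochs.
    lr = args.lr
    if epoch >= 75:
        lr = args.lr * 0.1
    if epoch >= 90:
        lr = args.lr * 0.01
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr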
Example #2
    def __init__(self, dataset, batch_size, meta_arch, meta_train_type, distill_loss, data_loss, norm, targeted, use_softmax, mode="meta"):
        if mode == "meta":
            target_str = "targeted_attack_random" if targeted else "untargeted_attack"
            # 2Q_DISTILLATION@CIFAR-100@TRAIN_I_TEST_II@model_resnet34@loss_pair_mse@dataloss_cw_l2_untargeted_attack@epoch_4@meta_batch_size_30@num_support_50@num_updates_12@lr_0.001@inner_lr_0.01.pth.tar
            self.meta_model_path = "{root}/train_pytorch_model/meta_simulator/{meta_train_type}@{dataset}@{split}@model_{meta_arch}@loss_{loss}@dataloss_{data_loss}_{norm}_{target_str}*inner_lr_0.01.pth.tar".format(
                root=PY_ROOT, meta_train_type=meta_train_type.upper(), dataset=dataset, split=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II,
                meta_arch=meta_arch, loss=distill_loss, data_loss=data_loss, norm=norm, target_str=target_str)
            log.info("start using {}".format(self.meta_model_path))
            self.meta_model_path = glob.glob(self.meta_model_path)
            pattern = re.compile(".*model_(.*?)@.*inner_lr_(.*?)\.pth.*")
            assert len(self.meta_model_path) > 0
            self.meta_model_path = self.meta_model_path[0]
            log.info("load meta model {}".format(self.meta_model_path))
            ma = pattern.match(os.path.basename(self.meta_model_path))
            arch = ma.group(1)
            self.inner_lr = float(ma.group(2))
            meta_backbone = self.construct_model(arch, dataset)
            self.meta_network = MetaNetwork(meta_backbone)
            self.pretrained_weights = torch.load(self.meta_model_path, map_location=lambda storage, location: storage)
            self.meta_network.load_state_dict(self.pretrained_weights["state_dict"])
            log.info("Load model in epoch {}.".format(self.pretrained_weights["epoch"]))
            self.pretrained_weights = self.pretrained_weights["state_dict"]
        elif mode == "vanilla":
            target_str = "targeted" if targeted else "untargeted"
            arch = meta_arch
            # 2Q_DISTILLATION@CIFAR-100@TRAIN_I_TEST_II@model_resnet34@loss_pair_mse@dataloss_cw_l2_untargeted_attack@epoch_4@meta_batch_size_30@num_support_50@num_updates_12@lr_0.001@inner_lr_0.01.pth.tar
            self.meta_model_path = "{root}/train_pytorch_model/vanilla_simulator/{dataset}@{norm}_norm_{target_str}@{meta_arch}*.tar".format(
                root=PY_ROOT, dataset=dataset,
                meta_arch=meta_arch,norm=norm, target_str=target_str)
            log.info("start using {}".format(self.meta_model_path))
            self.meta_model_path = glob.glob(self.meta_model_path)
            assert len(self.meta_model_path) > 0
            self.meta_model_path = self.meta_model_path[0]
            log.info("load meta model {}".format(self.meta_model_path))
            self.inner_lr = 0.01
            self.meta_network = self.construct_model(meta_arch, dataset)
            self.pretrained_weights = torch.load(self.meta_model_path, map_location=lambda storage, location: storage)
            log.info("Load model in epoch {}.".format(self.pretrained_weights["epoch"]))
            self.pretrained_weights = self.pretrained_weights["state_dict"]
        elif mode == "deep_benign_images":
            arch = "resnet34"
            self.inner_lr = 0.01
            self.meta_network = self.construct_model(arch, dataset)
            self.meta_model_path = "{root}/train_pytorch_model/real_image_model/{dataset}@{arch}@epoch_200@lr_0.1@batch_200.pth.tar".format(
                root=PY_ROOT, dataset=dataset, arch=arch)
            assert os.path.exists(self.meta_model_path), "{} does not exist!".format(self.meta_model_path)
            self.pretrained_weights = torch.load(self.meta_model_path, map_location=lambda storage, location: storage)[
                "state_dict"]
        elif mode == "random_init":
            arch = "resnet34"
            self.inner_lr = 0.01
            self.meta_network = self.construct_model(arch, dataset)
            self.pretrained_weights = self.meta_network.state_dict()
        elif mode == 'ensemble_avg':
            self.inner_lr = 0.01
            self.archs = ["densenet-bc-100-12","resnet-110","vgg19_bn"]
            self.meta_network = []  # in ensemble mode, meta_network and pretrained_weights both become lists
            self.pretrained_weights = []
            for arch in self.archs:
                model = StandardModel(dataset, arch, no_grad=False, load_pretrained=True)
                model.eval()
                model.cuda()
                self.meta_network.append(model)
                self.pretrained_weights.append(model.state_dict())
        elif mode == "benign_images":
            self.inner_lr = 0.01
            self.meta_model_path = "{root}/train_pytorch_model/meta_simulator_on_benign_images/{dataset}@{split}*@inner_lr_0.01.pth.tar".format(
                root=PY_ROOT, meta_train_type=meta_train_type.upper(), dataset=dataset,
                split=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II)
            self.meta_model_path = glob.glob(self.meta_model_path)
            pattern = re.compile(".*model_(.*?)@.*")
            assert len(self.meta_model_path) > 0
            self.meta_model_path = self.meta_model_path[0]
            ma = pattern.match(os.path.basename(self.meta_model_path))
            log.info("Loading meta model from {}".format(self.meta_model_path))
            arch = ma.group(1)
            self.pretrained_weights = torch.load(self.meta_model_path, map_location=lambda storage, location: storage)["state_dict"]
            meta_backbone = self.construct_model(arch, dataset)
            self.meta_network = MetaNetwork(meta_backbone)
            self.meta_network.load_state_dict(self.pretrained_weights)
            self.meta_network.eval()
            self.meta_network.cuda()
        elif mode == "reptile_on_benign_images":
            self.inner_lr = 0.01
            self.meta_model_path = "{root}/train_pytorch_model/meta_simulator_reptile_on_benign_images/{dataset}@{split}*@inner_lr_0.01.pth.tar".format(
                root=PY_ROOT, meta_train_type=meta_train_type.upper(), dataset=dataset,
                split=SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II)
            self.meta_model_path = glob.glob(self.meta_model_path)
            pattern = re.compile(".*model_(.*?)@.*")
            assert len(self.meta_model_path) > 0
            self.meta_model_path = self.meta_model_path[0]
            log.info("Loading meta model from {}".format(self.meta_model_path))
            ma = pattern.match(os.path.basename(self.meta_model_path))
            arch = ma.group(1)
            self.pretrained_weights = torch.load(self.meta_model_path, map_location=lambda storage, location: storage)["state_dict"]
            meta_backbone = self.construct_model(arch, dataset)
            self.meta_network = MetaNetwork(meta_backbone)
            self.meta_network.load_state_dict(self.pretrained_weights)
            self.meta_network.eval()
            self.meta_network.cuda()
        else:
            raise ValueError("unknown mode: {}".format(mode))

        self.arch = arch
        self.dataset = dataset
        self.need_pair_distance = (distill_loss.lower() == "pair_mse")
        self.softmax = nn.Softmax(dim=1)
        self.mse_loss = nn.MSELoss(reduction="mean")
        self.pair_wise_distance = nn.PairwiseDistance(p=2)
        self.use_softmax = use_softmax
        if mode != "ensemble_avg":
            self.meta_network.load_state_dict(self.pretrained_weights)
            self.meta_network.eval()
            self.meta_network.cuda()
        self.batch_size = batch_size
        if mode == 'ensemble_avg':
            self.batch_weights = defaultdict(dict)
            for idx in range(len(self.pretrained_weights)):
                for i in range(batch_size):
                    self.batch_weights[idx][i] = self.pretrained_weights[idx]
        else:
            self.batch_weights = dict()
            for i in range(batch_size):
                self.batch_weights[i] = self.pretrained_weights
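
The checkpoint filename encodes the backbone architecture and the inner learning rate, which the regex above recovers after glob resolves the wildcard. A minimal sketch of that parsing step in isolation, using the sample filename from the comment (the filename is illustrative only):

import os
import re

path = "2Q_DISTILLATION@CIFAR-100@TRAIN_I_TEST_II@model_resnet34@loss_pair_mse@dataloss_cw_l2_untargeted_attack@epoch_4@meta_batch_size_30@num_support_50@num_updates_12@lr_0.001@inner_lr_0.01.pth.tar"
pattern = re.compile(r".*model_(.*?)@.*inner_lr_(.*?)\.pth.*")
ma = pattern.match(os.path.basename(path))
assert ma is not None
arch = ma.group(1)             # "resnet34"
inner_lr = float(ma.group(2))  # 0.01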
Example #3
        model.zero_grad()
        loss.backward()
        optimizer.step()

    logger.info('[%d] train loss: adv: %.3f, clean: %.3f' %
                (epoch + 1, running_loss_1 / (i + 1), running_loss_2 / (i + 1)))

    if epoch % EVALUATE_EPOCH == 0:
        running_loss, correct, total = 0.0, 0.0, 0.0
        model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            for i, data_batch in enumerate(val_loader):
                # get the inputs; data_batch is a list of [inputs, labels]
                img_batch, label_batch = data_batch
                img_batch, label_batch = img_batch.cuda(), label_batch.cuda()
                output_batch = model(img_batch)
                loss = criterion(output_batch, label_batch)
                running_loss += loss.item()

                _, predicted = torch.max(output_batch.data, 1)
                # labels are one-hot encoded, so recover the class index
                _, label_ind = torch.max(label_batch.data, 1)
                correct += (predicted == label_ind).sum().item()
                total += label_batch.size(0)
        logger.info('[%d] test loss: %.3f, accuracy: %.3f' %
                    (epoch + 1, running_loss / (i + 1), correct / total))
        model.train()  # restore training mode for subsequent epochs

    if epoch % args.save_epoch == 0 or epoch == EPOCH_TOTAL - 1:
        torch.save(model.state_dict(),
                   os.path.join(MODELS_FOLDER, "eopch{}.ckpt".format(epoch)))

logger.info('Finished Training')
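
The evaluation above takes torch.max over label_batch, so the labels are one-hot encoded and criterion must accept one-hot targets. The criterion definition is not shown; a minimal sketch of a cross-entropy loss over one-hot (or soft) targets, as an assumption about the setup rather than code from the original:

import torch.nn.functional as F

def soft_cross_entropy(logits, one_hot_targets):
    # Cross-entropy between predicted log-probabilities and one-hot/soft targets.
    log_probs = F.log_softmax(logits, dim=1)
    return -(one_hot_targets * log_probs).sum(dim=1).mean()

criterion = soft_cross_entropy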
Example #4
def main():
    args = get_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    logger.info(args)
    model_path = '{}/train_pytorch_model/adversarial_train/fast_adv_train/{}@{}@epoch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs)
    out_dir = os.path.dirname(model_path)
    os.makedirs(out_dir, exist_ok=True)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    start_start_time = time.time()
    train_loader = DataLoaderMaker.get_img_label_data_loader(args.dataset, args.batch_size, True)
    test_loader = DataLoaderMaker.get_img_label_data_loader(args.dataset, args.batch_size, False)

    epsilon = (args.epsilon / 255.) / std
    pgd_alpha = (args.pgd_alpha / 255.) / std

    model = StandardModel(args.dataset, args.arch, no_grad=False)
    model.apply(initialize_weights)
    model.cuda()
    model.train()

    opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=5e-4)

    model, opt = amp.initialize(model, opt, opt_level="O2", loss_scale=1.0, master_weights=False)

    criterion = nn.CrossEntropyLoss()

    # 'free' training and FGSM with 'previous' init reuse one persistent delta
    if args.attack == 'free' or (args.attack == 'fgsm' and args.fgsm_init == 'previous'):
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        assert args.epochs % args.attack_iters == 0
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            # constant for the first half, then /10 at 50% and /100 at 75% of training
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    else:
        raise ValueError("unknown lr_schedule: {}".format(args.lr_schedule))

    prev_robust_acc = 0.
    logger.info('Epoch \t Time \t LR \t \t Train Loss \t Train Acc')
    for epoch in range(epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            X = X.cuda().float()
            y = y.cuda().long()
            if i == 0:
                first_batch = X, y
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, opt)

            elif args.attack == 'fgsm':
                if args.fgsm_init == 'zero':
                    delta = torch.zeros_like(X, requires_grad=True)
                elif args.fgsm_init == 'random':
                    delta = torch.zeros_like(X).cuda()
                    delta[:, 0, :, :].uniform_(-epsilon[0][0][0].item(), epsilon[0][0][0].item())
                    delta[:, 1, :, :].uniform_(-epsilon[1][0][0].item(), epsilon[1][0][0].item())
                    delta[:, 2, :, :].uniform_(-epsilon[2][0][0].item(), epsilon[2][0][0].item())
                    delta.requires_grad = True
                elif args.fgsm_init == 'previous':
                    delta.requires_grad = True
                output = model(X + delta[:X.size(0)])
                loss = F.cross_entropy(output, y)
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
                grad = delta.grad.detach()
                delta.data = clamp(delta + args.fgsm_alpha * epsilon * torch.sign(grad), -epsilon, epsilon)
                delta = delta.detach()

            elif args.attack == 'free':
                delta.requires_grad = True
                for j in range(args.attack_iters):
                    epoch_iters = epoch * args.attack_iters + (i * args.attack_iters + j + 1) / len(train_loader)
                    lr = lr_schedule(epoch_iters)
                    opt.param_groups[0].update(lr=lr)
                    output = model(clamp(X + delta[:X.size(0)], lower_limit, upper_limit))
                    loss = F.cross_entropy(output, y)
                    opt.zero_grad()
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        scaled_loss.backward()
                    grad = delta.grad.detach()
                    delta.data = clamp(delta + epsilon * torch.sign(grad), -epsilon, epsilon)
                    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                    opt.step()
                    delta.grad.zero_()
            elif args.attack == 'none':
                delta = torch.zeros_like(X)

            output = model(clamp(X + delta[:X.size(0)], lower_limit, upper_limit))
            loss = criterion(output, y)
            if args.attack != 'free':
                opt.zero_grad()
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                opt.step()

            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        if args.overfit_check:
            # Check current PGD robustness of model using random minibatch
            X, y = first_batch
            pgd_delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, opt)
            with torch.no_grad():
                output = model(clamp(X + pgd_delta[:X.size(0)], lower_limit, upper_limit))
            robust_acc = (output.max(1)[1] == y).sum().item() / y.size(0)
            if robust_acc - prev_robust_acc < -0.5:
                break
            prev_robust_acc = robust_acc
        best_state_dict = copy.deepcopy(model.state_dict())  # snapshot the latest weights that passed the robustness check

        train_time = time.time()
        logger.info('%d \t %.1f \t %.4f \t %.4f \t %.4f',
            epoch, train_time - start_time, lr, train_loss/train_n, train_acc/train_n)
    torch.save(best_state_dict, model_path)
    logger.info('Total time: %.4f', train_time - start_start_time)
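
The clamp helper and the std, lower_limit, and upper_limit globals used throughout are not shown. A minimal sketch, assuming CIFAR-10 channel statistics as in common fast adversarial training setups (the mean/std values are assumptions; the original may normalize differently):

import torch

cifar10_mean = (0.4914, 0.4822, 0.4465)  # assumed dataset statistics
cifar10_std = (0.2471, 0.2435, 0.2616)

mu = torch.tensor(cifar10_mean).view(3, 1, 1).cuda()
std = torch.tensor(cifar10_std).view(3, 1, 1).cuda()

# Images normalized by (x - mu) / std live in [(0 - mu) / std, (1 - mu) / std].
lower_limit = (0. - mu) / std
upper_limit = (1. - mu) / std

def clamp(X, lower, upper):
    # Element-wise clamp of X into [lower, upper]; bounds broadcast over the batch.
    return torch.max(torch.min(X, upper), lower)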