def main():
    """Train a model with the attack selected via the command line, evaluate it periodically and save it.

    The function:
      1. parses the arguments, builds a run name from them and configures the logger;
      2. seeds numpy / torch, builds the data loaders, the model and its optimizer;
      3. runs `args.epochs + 1` epochs, where epoch 0 processes a single batch without an
         optimizer step (only to report the statistics at initialization);
      4. every `args.eval_iter_freq` iterations evaluates clean / FGSM / PGD accuracy on the
         small ("fast") loaders and remembers the weights with the best training PGD accuracy;
      5. after the last epoch saves the last and the best weights under `models/` and runs the
         final evaluation in fp32 with a multi-restart PGD attack.

    Raises:
        ValueError: if no optimizer is defined for `args.model`, or if `args.attack` /
            `args.attack_init` has an unsupported value.
    """
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Millisecond-resolution timestamp to make run-name collisions unlikely. strftime always emits the
    # fractional part, whereas str(datetime.now()) drops it when microsecond == 0 (and the [:-3] slice
    # would then cut off the seconds instead).
    cur_timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
    model_width = {'linear': '', 'cnn': args.n_filters_cnn, 'lenet': '', 'resnet18': ''}[args.model]
    model_str = '{}{}'.format(args.model, model_width)
    model_name = '{} dataset={} model={} eps={} attack={} m={} attack_init={} fgsm_alpha={} epochs={} pgd={}-{} grad_align_cos_lambda={} lr_max={} seed={}'.format(
        cur_timestamp, args.dataset, model_str, args.eps, args.attack, args.minibatch_replay, args.attack_init, args.fgsm_alpha, args.epochs,
        args.pgd_alpha_train, args.pgd_train_n_iters, args.grad_align_cos_lambda, args.lr_max, args.seed)
    # exist_ok avoids the check-then-create race when several runs are started at the same time
    os.makedirs('models', exist_ok=True)
    logger = utils.configure_logger(model_name, args.debug)
    logger.info(args)
    half_prec = args.half_prec
    n_cls = 2 if 'binary' in args.dataset else 10

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # the gradient-alignment regularizer differentiates through an input gradient, i.e. needs double backprop
    double_bp = args.grad_align_cos_lambda > 0
    n_eval_every_k_iter = args.n_eval_every_k_iter
    args.pgd_alpha = args.eps / 4  # step size of the evaluation PGD attack

    # eps and the step sizes are given on the command line in units of 1/255
    eps, pgd_alpha, pgd_alpha_train = args.eps / 255, args.pgd_alpha / 255, args.pgd_alpha_train / 255
    train_data_augm = args.dataset not in ['mnist']
    train_batches = data.get_loaders(args.dataset, -1, args.batch_size, train_set=True, shuffle=True, data_augm=train_data_augm)
    train_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size, train_set=True, shuffle=False, data_augm=False)
    test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)
    test_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)

    model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn).cuda()
    model.apply(utils.initialize_weights)
    model.train()

    if args.model == 'resnet18':
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)
    elif args.model in ['cnn', 'lenet']:
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    else:
        raise ValueError('decide about the right optimizer for the new model')

    if half_prec:
        if double_bp:
            amp.register_float_function(torch, 'batch_norm')
        model, opt = amp.initialize(model, opt, opt_level="O1")

    if args.attack == 'fgsm':  # needed here only for Free-AT
        # NOTE(review): this persistent delta has a fixed first dimension of args.batch_size, so Free-AT
        # presumably relies on the training loader never yielding a smaller last batch -- confirm.
        delta = torch.zeros(args.batch_size, *data.shapes_dict[args.dataset][1:]).cuda()
        delta.requires_grad = True

    lr_schedule = utils.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_max)
    loss_function = nn.CrossEntropyLoss()

    # svhn only: eps (PGD) / alpha (FGSM) are increased linearly during the first iterations.
    # shapes_dict[dataset][0] is presumably the number of training examples, which would make this
    # the number of iterations in `n_warmup_epochs` epochs -- TODO confirm.
    n_warmup_epochs = 5
    n_warmup_iterations = n_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size

    train_acc_pgd_best, best_state_dict = 0.0, copy.deepcopy(model.state_dict())
    start_time = time.time()
    time_train, iteration, best_iteration = 0, 0, 0
    for epoch in range(args.epochs + 1):
        # running sums over the iterations since the last periodic evaluation
        # NOTE: grad_norm_x is accumulated below but never reported anywhere in this function
        train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0
        for i, (X, y) in enumerate(train_batches):
            if i % args.minibatch_replay != 0 and i > 0:  # take new inputs only each `minibatch_replay` iterations
                X, y = X_prev, y_prev
            time_start_iter = time.time()
            # epoch=0 runs only for one iteration (to check the training stats at init)
            if epoch == 0 and i > 0:
                break
            X, y = X.cuda(), y.cuda()
            lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches))  # epoch - 1 since the 0th epoch is skipped
            opt.param_groups[0].update(lr=lr)

            # ---- craft the perturbation `delta` used for this training step ----
            if args.attack in ['pgd', 'pgd_corner']:
                pgd_rs = args.attack_init == 'random'
                eps_pgd_train = min(iteration / n_warmup_iterations * eps, eps) if args.dataset == 'svhn' else eps
                delta = utils.attack_pgd_training(
                    model, X, y, eps_pgd_train, pgd_alpha_train, opt, half_prec, args.pgd_train_n_iters, rs=pgd_rs)
                if args.attack == 'pgd_corner':
                    delta = eps * utils.sign(delta)  # project to the corners
                    delta = clamp(X + delta, 0, 1) - X

            elif args.attack == 'fgsm':
                if args.minibatch_replay == 1:
                    if args.attack_init == 'zero':
                        delta = torch.zeros_like(X, requires_grad=True)
                    elif args.attack_init == 'random':
                        delta = utils.get_uniform_delta(X.shape, eps, requires_grad=True)
                    else:
                        raise ValueError('wrong args.attack_init')
                else:  # if Free-AT, we just reuse the existing delta from the previous iteration
                    delta.requires_grad = True

                X_adv = clamp(X + delta, 0, 1)
                output = model(X_adv)
                loss = F.cross_entropy(output, y)
                if half_prec:
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        grad = torch.autograd.grad(scaled_loss, delta, create_graph=double_bp)[0]
                        grad /= scaled_loss / loss  # reverse back the scaling
                else:
                    grad = torch.autograd.grad(loss, delta, create_graph=double_bp)[0]

                grad = grad.detach()

                argmax_delta = eps * utils.sign(grad)

                fgsm_alpha = min(iteration / n_warmup_iterations * args.fgsm_alpha, args.fgsm_alpha) if args.dataset == 'svhn' else args.fgsm_alpha
                # FGSM step, then projection onto the eps-ball and onto the valid input range [0, 1]
                delta.data = clamp(delta.data + fgsm_alpha * argmax_delta, -eps, eps)
                delta.data = clamp(X + delta.data, 0, 1) - X

            elif args.attack == 'random_corner':
                delta = utils.get_uniform_delta(X.shape, eps, requires_grad=False)
                delta = eps * utils.sign(delta)

            elif args.attack == 'none':
                delta = torch.zeros_like(X, requires_grad=False)
            else:
                raise ValueError('wrong args.attack')

            # extra FP+BP to calculate the gradient to monitor it
            # (for 'fgsm' the input gradient `grad` was already computed above)
            if args.attack in ['none', 'random_corner', 'pgd', 'pgd_corner']:
                grad = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none',
                                      backprop=args.grad_align_cos_lambda != 0.0)

            # the perturbation is treated as a constant in the training loss
            delta = delta.detach()

            output = model(X + delta)
            loss = loss_function(output, y)

            reg = torch.zeros(1).cuda()[0]  # for .item() to run correctly
            if args.grad_align_cos_lambda != 0.0:
                # regularizer: lambda * (1 - mean cosine) between `grad` and the input gradient
                # obtained with delta_init='random_uniform'
                grad2 = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='random_uniform', backprop=True)
                # keep only the examples where both gradients are non-zero so that the normalization is defined
                grads_nnz_idx = ((grad**2).sum([1, 2, 3])**0.5 != 0) * ((grad2**2).sum([1, 2, 3])**0.5 != 0)
                grad1, grad2 = grad[grads_nnz_idx], grad2[grads_nnz_idx]
                grad1_norms, grad2_norms = l2_norm_batch(grad1), l2_norm_batch(grad2)
                grad1_normalized = grad1 / grad1_norms[:, None, None, None]
                grad2_normalized = grad2 / grad2_norms[:, None, None, None]
                cos = torch.sum(grad1_normalized * grad2_normalized, (1, 2, 3))
                reg += args.grad_align_cos_lambda * (1.0 - cos.mean())

            loss += reg

            if epoch != 0:  # no parameter update in the stats-only epoch 0
                opt.zero_grad()
                utils.backward(loss, opt, half_prec)
                opt.step()

            time_train += time.time() - time_start_iter
            train_loss += loss.item() * y.size(0)
            train_reg += reg.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

            with torch.no_grad():  # no grad for the stats
                grad_norm_x += l2_norm_batch(grad).sum().item()
                delta_final = clamp(X + delta, 0, 1) - X  # we should measure delta after the projection onto [0, 1]^d
                avg_delta_l2 += ((delta_final ** 2).sum([1, 2, 3]) ** 0.5).sum().item()

            # ---- periodic evaluation on the small loaders ----
            if iteration % args.eval_iter_freq == 0:
                train_loss, train_reg = train_loss / train_n, train_reg / train_n
                train_acc, avg_delta_l2 = train_acc / train_n, avg_delta_l2 / train_n

                # it'd be incorrect to recalculate the BN stats on the test sets and for clean / adversarial points
                utils.model_eval(model, half_prec)

                # the last two positional arguments of rob_acc are (attack iterations, number of restarts)
                test_acc_clean, _, _ = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_fgsm, test_loss_fgsm, fgsm_deltas = rob_acc(test_batches_fast, model, eps, eps, opt, half_prec, 1, 1, rs=False)
                test_acc_pgd, test_loss_pgd, pgd_deltas = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)
                cos_fgsm_pgd = utils.avg_cos_np(fgsm_deltas, pgd_deltas)
                train_acc_pgd, _, _ = rob_acc(train_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)  # needed for early stopping

                grad_x = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=False)
                grad_eta = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=True)
                cos_x_eta = utils.avg_cos_np(grad_x, grad_eta)

                time_elapsed = time.time() - start_time
                train_str = '[train] loss {:.3f}, reg {:.3f}, acc {:.2%} acc_pgd {:.2%}'.format(train_loss, train_reg, train_acc, train_acc_pgd)
                test_str = '[test] acc_clean {:.2%}, acc_fgsm {:.2%}, acc_pgd {:.2%}, cos_x_eta {:.3}, cos_fgsm_pgd {:.3}'.format(
                    test_acc_clean, test_acc_fgsm, test_acc_pgd, cos_x_eta, cos_fgsm_pgd)
                logger.info('{}-{}: {}  {} ({:.2f}m, {:.2f}m)'.format(epoch, iteration, train_str, test_str,
                                                                      time_train/60, time_elapsed/60))

                if train_acc_pgd > train_acc_pgd_best:  # catastrophic overfitting can be detected on the training set
                    best_state_dict = copy.deepcopy(model.state_dict())
                    train_acc_pgd_best, best_iteration = train_acc_pgd, iteration

                utils.model_train(model, half_prec)
                # start a fresh accumulation window
                train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0

            iteration += 1
            X_prev, y_prev = X.clone(), y.clone()  # needed for Free-AT

        # ---- after the last epoch: save the weights and run the final evaluation ----
        if epoch == args.epochs:
            torch.save({'last': model.state_dict(), 'best': best_state_dict}, 'models/{} epoch={}.pth'.format(model_name, epoch))
            # disable global conversion to fp16 from amp.initialize() (https://github.com/NVIDIA/apex/issues/567)
            context_manager = amp.disable_casts() if half_prec else utils.nullcontext()
            with context_manager:
                last_state_dict = copy.deepcopy(model.state_dict())
                half_prec = False  # final eval is always in fp32
                model.load_state_dict(last_state_dict)
                utils.model_eval(model, half_prec)
                # fresh optimizer with lr=0 that is handed to rob_acc for the final evaluation
                opt = torch.optim.SGD(model.parameters(), lr=0)

                attack_iters, n_restarts = (50, 10) if not args.debug else (10, 3)
                test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                logger.info('[last: test on 10k points] acc_clean {:.2%}, pgd_rr {:.2%}'.format(test_acc_clean, test_acc_pgd_rr))

                if args.eval_early_stopped_model:
                    # repeat the final evaluation for the weights with the best training PGD accuracy
                    model.load_state_dict(best_state_dict)
                    utils.model_eval(model, half_prec)
                    test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                    test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                    logger.info('[best: test on 10k points][iter={}] acc_clean {:.2%}, pgd_rr {:.2%}'.format(
                        best_iteration, test_acc_clean, test_acc_pgd_rr))

        utils.model_train(model, half_prec)

    logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60))
示例#2
0
                                metrics = { embedding_layer_name : metric, 'prob' : 'accuracy' })
            else:
                par_model.compile(optimizer = keras.optimizers.SGD(lr=args.sgd_lr, momentum=0.9, nesterov=args.nesterov, clipnorm = args.clipgrad),
                                loss = loss,
                                metrics = [metric])
            par_model.fit_generator(
                    data_generator.train_sequence(args.batch_size, batch_transform = transform_inputs, batch_transform_kwargs = batch_transform_kwargs),
                    validation_data = data_generator.test_sequence(args.val_batch_size, batch_transform = transform_inputs, batch_transform_kwargs = batch_transform_kwargs),
                    epochs = args.finetune_init, verbose = not args.no_progress,
                    max_queue_size = args.queue_size, workers = args.read_workers, use_multiprocessing = True)
            for layer in model.layers:
                layer.trainable = True
            print('Full model training')

    # Train model
    callbacks, num_epochs = utils.get_lr_schedule(args.lr_schedule, data_generator.num_train, args.batch_size, schedule_args = { arg_name : arg_val for arg_name, arg_val in vars(args).items() if arg_val is not None })

    if args.log_dir:
        if os.path.isdir(args.log_dir):
            shutil.rmtree(args.log_dir, ignore_errors = True)
        callbacks.append(keras.callbacks.TensorBoard(log_dir = args.log_dir, write_graph = False))
    
    if args.snapshot:
        snapshot_kwargs = {}
        if args.snapshot_best:
            snapshot_kwargs['save_best_only'] = True
            snapshot_kwargs['monitor'] = args.snapshot_best
        callbacks.append(keras.callbacks.ModelCheckpoint(args.snapshot, **snapshot_kwargs) if args.gpus <= 1 else utils.TemplateModelCheckpoint(model, args.snapshot, **snapshot_kwargs))

    if args.max_decay > 0:
        decay = (1.0/args.max_decay - 1) / ((data_generator.num_train // args.batch_size) * (args.epochs if args.epochs else num_epochs))