Example #1
def main():
    # check the configurations
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # prepare data for training
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='../data',
                                            train=True,
                                            download=True,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              shuffle=True,
                                              batch_size=128,
                                              num_workers=4,
                                              pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='../data',
                                           train=False,
                                           download=True,
                                           transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             shuffle=False,
                                             batch_size=128,
                                             num_workers=4,
                                             pin_memory=True)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # initialize the model
    net = VGG().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=0.05,
                          momentum=0.9,
                          weight_decay=5e-4)

    # training or loading the neural network
    # model_train(net, trainloader, criterion, optimizer, epochs=5)
    net.load_state_dict(
        torch.load('resources/vgg16_cifar10.bin', map_location='cpu'))
    print('Neural network ready.')

    # evaluate the model performance
    accuracy, _ = model_eval(net, testloader, criterion)
    print('Accuracy of the network on the clean test images: %d %%' %
          (100 * accuracy))

    accuracy, _ = model_eval(net,
                             testloader,
                             criterion,
                             attack_method=illcm_attack)
    print('Accuracy of the network on the adversarial test images: %d %%' %
          (100 * accuracy))
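
The `illcm_attack` passed to `model_eval` above is not defined in this snippet. A minimal sketch of the iterative least-likely-class method (Kurakin et al., 2017), assuming a `(model, x, y)` attack signature, that `torch.nn.functional` is imported as `F`, and illustrative step/budget defaults:

def illcm_attack(model, x, y=None, eps=16 / 255, alpha=2 / 255, n_iter=10):
    # iteratively step *towards* the class the model currently considers
    # least likely, projecting back into the eps-ball around x each step
    x_adv = x.clone().detach()
    with torch.no_grad():
        y_ll = model(x_adv).argmin(dim=1)  # least-likely class per example
    for _ in range(n_iter):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y_ll)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() - alpha * grad.sign()  # descend on the LL loss
        x_adv = x + torch.clamp(x_adv - x, -eps, eps)
        x_adv = torch.clamp(x_adv, -1, 1)  # inputs were normalized to [-1, 1]
    return x_adv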
Example #2
def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    cur_timestamp = str(datetime.now())[:-3]  # keep milliseconds to reduce the chance of a name collision
    model_width = {'linear': '', 'cnn': args.n_filters_cnn, 'lenet': '', 'resnet18': ''}[args.model]
    model_str = '{}{}'.format(args.model, model_width)
    model_name = '{} dataset={} model={} eps={} attack={} m={} attack_init={} fgsm_alpha={} epochs={} pgd={}-{} grad_align_cos_lambda={} lr_max={} seed={}'.format(
        cur_timestamp, args.dataset, model_str, args.eps, args.attack, args.minibatch_replay, args.attack_init, args.fgsm_alpha, args.epochs,
        args.pgd_alpha_train, args.pgd_train_n_iters, args.grad_align_cos_lambda, args.lr_max, args.seed)
    if not os.path.exists('models'):
        os.makedirs('models')
    logger = utils.configure_logger(model_name, args.debug)
    logger.info(args)
    half_prec = args.half_prec
    n_cls = 2 if 'binary' in args.dataset else 10

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    double_bp = args.grad_align_cos_lambda > 0
    n_eval_every_k_iter = args.n_eval_every_k_iter
    args.pgd_alpha = args.eps / 4

    eps, pgd_alpha, pgd_alpha_train = args.eps / 255, args.pgd_alpha / 255, args.pgd_alpha_train / 255
    train_data_augm = args.dataset not in ['mnist']
    train_batches = data.get_loaders(args.dataset, -1, args.batch_size, train_set=True, shuffle=True, data_augm=train_data_augm)
    train_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size, train_set=True, shuffle=False, data_augm=False)
    test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)
    test_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)

    model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn).cuda()
    model.apply(utils.initialize_weights)
    model.train()

    if args.model == 'resnet18':
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)
    elif args.model in ['cnn', 'lenet']:
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    else:
        raise ValueError('no optimizer specified for model {}'.format(args.model))

    if half_prec:
        if double_bp:
            amp.register_float_function(torch, 'batch_norm')
        model, opt = amp.initialize(model, opt, opt_level="O1")

    if args.attack == 'fgsm':  # needed here only for Free-AT
        delta = torch.zeros(args.batch_size, *data.shapes_dict[args.dataset][1:]).cuda()
        delta.requires_grad = True

    lr_schedule = utils.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_max)
    loss_function = nn.CrossEntropyLoss()

    train_acc_pgd_best, best_state_dict = 0.0, copy.deepcopy(model.state_dict())
    start_time = time.time()
    time_train, iteration, best_iteration = 0, 0, 0
    for epoch in range(args.epochs + 1):
        train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0
        for i, (X, y) in enumerate(train_batches):
            if i % args.minibatch_replay != 0 and i > 0:  # take new inputs only once every `minibatch_replay` iterations
                X, y = X_prev, y_prev
            time_start_iter = time.time()
            # epoch=0 runs only for one iteration (to check the training stats at init)
            if epoch == 0 and i > 0:
                break
            X, y = X.cuda(), y.cuda()
            lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches))  # epoch - 1 since the 0th epoch is skipped
            opt.param_groups[0].update(lr=lr)

            if args.attack in ['pgd', 'pgd_corner']:
                pgd_rs = (args.attack_init == 'random')
                n_eps_warmup_epochs = 5
                n_iterations_max_eps = n_eps_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                eps_pgd_train = min(iteration / n_iterations_max_eps * eps, eps) if args.dataset == 'svhn' else eps
                delta = utils.attack_pgd_training(
                    model, X, y, eps_pgd_train, pgd_alpha_train, opt, half_prec, args.pgd_train_n_iters, rs=pgd_rs)
                if args.attack == 'pgd_corner':
                    delta = eps * utils.sign(delta)  # project to the corners
                    delta = clamp(X + delta, 0, 1) - X

            elif args.attack == 'fgsm':
                if args.minibatch_replay == 1:
                    if args.attack_init == 'zero':
                        delta = torch.zeros_like(X, requires_grad=True)
                    elif args.attack_init == 'random':
                        delta = utils.get_uniform_delta(X.shape, eps, requires_grad=True)
                    else:
                        raise ValueError('wrong args.attack_init')
                else:  # if Free-AT, we just reuse the existing delta from the previous iteration
                    delta.requires_grad = True

                X_adv = clamp(X + delta, 0, 1)
                output = model(X_adv)
                loss = F.cross_entropy(output, y)
                if half_prec:
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        grad = torch.autograd.grad(scaled_loss, delta, create_graph=double_bp)[0]
                        grad /= scaled_loss / loss  # undo the loss scaling
                else:
                    grad = torch.autograd.grad(loss, delta, create_graph=double_bp)[0]

                grad = grad.detach()

                argmax_delta = eps * utils.sign(grad)

                n_alpha_warmup_epochs = 5
                n_iterations_max_alpha = n_alpha_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                fgsm_alpha = min(iteration / n_iterations_max_alpha * args.fgsm_alpha, args.fgsm_alpha) if args.dataset == 'svhn' else args.fgsm_alpha
                delta.data = clamp(delta.data + fgsm_alpha * argmax_delta, -eps, eps)
                delta.data = clamp(X + delta.data, 0, 1) - X

            elif args.attack == 'random_corner':
                delta = utils.get_uniform_delta(X.shape, eps, requires_grad=False)
                delta = eps * utils.sign(delta)

            elif args.attack == 'none':
                delta = torch.zeros_like(X, requires_grad=False)
            else:
                raise ValueError('wrong args.attack')

            # extra FP+BP to calculate the gradient to monitor it
            if args.attack in ['none', 'random_corner', 'pgd', 'pgd_corner']:
                grad = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none',
                                      backprop=args.grad_align_cos_lambda != 0.0)

            delta = delta.detach()

            output = model(X + delta)
            loss = loss_function(output, y)

            reg = torch.zeros(1).cuda()[0]  # for .item() to run correctly
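            # GradAlign regularizer: penalize misalignment between the input
            # gradient at the clean input (grad) and the input gradient at a
            # random point in the eps-ball (grad2) via 1 - cos(grad, grad2)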
            if args.grad_align_cos_lambda != 0.0:
                grad2 = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='random_uniform', backprop=True)
                grads_nnz_idx = ((grad**2).sum([1, 2, 3])**0.5 != 0) * ((grad2**2).sum([1, 2, 3])**0.5 != 0)
                grad1, grad2 = grad[grads_nnz_idx], grad2[grads_nnz_idx]
                grad1_norms, grad2_norms = l2_norm_batch(grad1), l2_norm_batch(grad2)
                grad1_normalized = grad1 / grad1_norms[:, None, None, None]
                grad2_normalized = grad2 / grad2_norms[:, None, None, None]
                cos = torch.sum(grad1_normalized * grad2_normalized, (1, 2, 3))
                reg += args.grad_align_cos_lambda * (1.0 - cos.mean())

            loss += reg

            if epoch != 0:
                opt.zero_grad()
                utils.backward(loss, opt, half_prec)
                opt.step()

            time_train += time.time() - time_start_iter
            train_loss += loss.item() * y.size(0)
            train_reg += reg.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

            with torch.no_grad():  # no grad for the stats
                grad_norm_x += l2_norm_batch(grad).sum().item()
                delta_final = clamp(X + delta, 0, 1) - X  # we should measure delta after the projection onto [0, 1]^d
                avg_delta_l2 += ((delta_final ** 2).sum([1, 2, 3]) ** 0.5).sum().item()

            if iteration % args.eval_iter_freq == 0:
                train_loss, train_reg = train_loss / train_n, train_reg / train_n
                train_acc, avg_delta_l2 = train_acc / train_n, avg_delta_l2 / train_n

                # it'd be incorrect to recalculate the BN stats on the test sets and for clean / adversarial points
                utils.model_eval(model, half_prec)

                test_acc_clean, _, _ = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_fgsm, test_loss_fgsm, fgsm_deltas = rob_acc(test_batches_fast, model, eps, eps, opt, half_prec, 1, 1, rs=False)
                test_acc_pgd, test_loss_pgd, pgd_deltas = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)
                cos_fgsm_pgd = utils.avg_cos_np(fgsm_deltas, pgd_deltas)
                train_acc_pgd, _, _ = rob_acc(train_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)  # needed for early stopping

                grad_x = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=False)
                grad_eta = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=True)
                cos_x_eta = utils.avg_cos_np(grad_x, grad_eta)

                time_elapsed = time.time() - start_time
                train_str = '[train] loss {:.3f}, reg {:.3f}, acc {:.2%} acc_pgd {:.2%}'.format(train_loss, train_reg, train_acc, train_acc_pgd)
                test_str = '[test] acc_clean {:.2%}, acc_fgsm {:.2%}, acc_pgd {:.2%}, cos_x_eta {:.3}, cos_fgsm_pgd {:.3}'.format(
                    test_acc_clean, test_acc_fgsm, test_acc_pgd, cos_x_eta, cos_fgsm_pgd)
                logger.info('{}-{}: {}  {} ({:.2f}m, {:.2f}m)'.format(epoch, iteration, train_str, test_str,
                                                                      time_train/60, time_elapsed/60))

                if train_acc_pgd > train_acc_pgd_best:  # catastrophic overfitting can be detected on the training set
                    best_state_dict = copy.deepcopy(model.state_dict())
                    train_acc_pgd_best, best_iteration = train_acc_pgd, iteration

                utils.model_train(model, half_prec)
                train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0

            iteration += 1
            X_prev, y_prev = X.clone(), y.clone()  # needed for Free-AT

        if epoch == args.epochs:
            torch.save({'last': model.state_dict(), 'best': best_state_dict}, 'models/{} epoch={}.pth'.format(model_name, epoch))
            # disable global conversion to fp16 from amp.initialize() (https://github.com/NVIDIA/apex/issues/567)
            context_manager = amp.disable_casts() if half_prec else utils.nullcontext()
            with context_manager:
                last_state_dict = copy.deepcopy(model.state_dict())
                half_prec = False  # final eval is always in fp32
                model.load_state_dict(last_state_dict)
                utils.model_eval(model, half_prec)
                opt = torch.optim.SGD(model.parameters(), lr=0)

                attack_iters, n_restarts = (50, 10) if not args.debug else (10, 3)
                test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                logger.info('[last: test on 10k points] acc_clean {:.2%}, pgd_rr {:.2%}'.format(test_acc_clean, test_acc_pgd_rr))

                if args.eval_early_stopped_model:
                    model.load_state_dict(best_state_dict)
                    utils.model_eval(model, half_prec)
                    test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                    test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                    logger.info('[best: test on 10k points][iter={}] acc_clean {:.2%}, pgd_rr {:.2%}'.format(
                        best_iteration, test_acc_clean, test_acc_pgd_rr))

        utils.model_train(model, half_prec)

    logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60))
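
The training loop above relies on two helpers that the snippet does not include, `l2_norm_batch` and `get_input_grad`. A plausible sketch consistent with how they are called here; the amp loss-scaling path tied to `opt` and `half_prec` is omitted, so treat this as an assumption rather than the original implementation:

def l2_norm_batch(v):
    # per-example L2 norm over the channel and spatial dimensions
    return (v ** 2).sum([1, 2, 3]) ** 0.5


def get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none',
                   backprop=False):
    # gradient of the cross-entropy loss w.r.t. the input, taken at the clean
    # input (delta_init='none') or at a uniformly random point in the eps-ball
    # (delta_init='random_uniform')
    if delta_init == 'none':
        delta = torch.zeros_like(X, requires_grad=True)
    elif delta_init == 'random_uniform':
        delta = torch.zeros_like(X).uniform_(-eps, eps).requires_grad_(True)
    else:
        raise ValueError('wrong delta_init')
    loss = F.cross_entropy(model(X + delta), y)
    grad = torch.autograd.grad(loss, delta, create_graph=backprop)[0]
    return grad if backprop else grad.detach()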
Example #3
            print('===========epoch%d,now_time is %s=========' %
                  (epoch + n * params['maxEpoch'],
                   datetime.datetime.now().strftime('%X')))
            ser_MMNet, loss = model_loss(sess, MMNet_nodes, data_Feed)
            print('ser_MMNet', ser_MMNet, 'loss=', loss, '\n')
            loss_MMNet.append(loss)
loss_all['MMNet'] = loss_MMNet
train_ed = time.time()
print("MMNet Train time is: " + str(train_ed - train_st))

# Testing
results = {
    'SNR_dBs':
    np.arange(params['SNR_dB_min_test'], params['SNR_dB_max_test'],
              params['SNR_step_test'])
}
ser_MMNets = model_eval(sess, params, MMNet_nodes, gen_data,
                        params['test_iterations'])
results['ser_MMNets'] = ser_MMNets
print(ser_MMNets)

# plot
for key, value in results.items():
    if key != 'SNR_dBs':
        print(key, value)
        plt.plot(results['SNR_dBs'], value, label=key)
plt.grid(True, which='minor', linestyle='--')
plt.yscale('log')
plt.xlabel('SNR (dB)')
plt.ylabel('SER')
plt.title('Nr%dNt%d_mod%s' % (params['Nr'], params['Nt'], params['mod_name']))
Example #4
            print('===========epoch%d,now_time is %s=========' %
                  (epoch + n * params['maxEpoch'],
                   datetime.datetime.now().strftime('%X')))
            ser, loss = model_loss(sess, DetNetSIC2_nodes, data_Feed)
            print('ser', ser, 'loss=', loss, '\n')
            loss_all.append(loss)
train_ed = time.time()
print("Train time is: " + str(train_ed - train_st))

# Testing
results = {
    'SNR_dBs':
    np.arange(params['SNR_dB_min_test'], params['SNR_dB_max_test'],
              params['SNR_step_test'])
}
sers = model_eval(sess, params, DetNetSIC2_nodes, gen_data,
                  params['test_iterations'])
results['ser_DetNetSIC2s'] = sers
print(sers)

# plot
for key, value in results.items():
    if key != 'SNR_dBs':
        print(key, value)
        plt.plot(results['SNR_dBs'], value, label=key)
plt.grid(True, which='minor', linestyle='--')
plt.yscale('log')
plt.xlabel('SNR (dB)')
plt.ylabel('SER')
plt.title('Nr%dNt%d_mod%s' % (params['Nr'], params['Nt'], params['mod_name']))
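
Examples #3 and #4 both call a `model_eval` helper that sweeps the test SNR range and returns one SER value per SNR point. A rough TF1-style sketch of such a routine; the node key `'ser'` and the `gen_data(params, snr)` feed-dict interface are assumptions, not the original code:

def model_eval(sess, params, nodes, gen_data, test_iterations):
    # average the symbol error rate over `test_iterations` batches per SNR
    snrs = np.arange(params['SNR_dB_min_test'], params['SNR_dB_max_test'],
                     params['SNR_step_test'])
    sers = []
    for snr in snrs:
        ser_acc = 0.0
        for _ in range(test_iterations):
            feed_dict = gen_data(params, snr)  # assumed: builds a test batch
            ser_acc += sess.run(nodes['ser'], feed_dict=feed_dict)
        sers.append(ser_acc / test_iterations)
    return sers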
Example #5
def train(model, X_train_synth, y_train_synth, X_val_sx, y_val_sx):
    """
        train and cross validate model
    """

    # prepare sx3 dataset
    dataset_nomp1 = SX3Dataset(label=0,
                               global_path=data_path +
                               'sx_data/snap_no_mp_SX3_5_sat_11_89x81')
    dataset_nomp2 = SX3Dataset(label=0,
                               global_path=data_path +
                               'sx_data/snap_no_mp_SX3_5_sat_18_89x81')
    dataset_mp1 = SX3Dataset(label=1,
                             global_path=data_path +
                             'sx_data/snap_mp_SX3_5_sat_11_89x81')
    dataset_mp2 = SX3Dataset(label=1,
                             global_path=data_path +
                             'sx_data/snap_mp_SX3_5_sat_18_89x81')
    data_nomp1 = dataset_nomp1.build(discr_shape=(80, 80), nb_samples=100)
    data_nomp2 = dataset_nomp2.build(discr_shape=(80, 80), nb_samples=100)
    data_mp1 = dataset_mp1.build(discr_shape=(80, 80), nb_samples=100)
    data_mp2 = dataset_mp2.build(discr_shape=(80, 80), nb_samples=100)

    data_val = np.concatenate((data_mp1, data_mp2, data_nomp1, data_nomp2),
                              axis=0)
    np.random.shuffle(data_val)

    X_val_sx = np.array([x['table'] for x in data_val])
    y_val_sx = np.array([x['label'] for x in data_val])

    # prepare the synthetic data from the data generator
    discr = 80
    X_train_synth, X_val_synth, y_train_synth, y_val_synth = load_ds_data(
        discr, data_path)

    # scale dataset
    X_train_synth = (X_train_synth -
                     X_train_synth.mean()) / X_train_synth.std()
    X_val_sx = (X_val_sx - X_val_sx.mean()) / X_val_sx.std()

    # define model params
    learning_rate = 1e-4
    optimizer = optimizers.Adam(lr=learning_rate)
    batch_size = 8
    train_iters = 20

    attn_model = model
    datagen = ImageDataGenerator(
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.01,
        zoom_range=[0.9, 1.25],
        fill_mode='nearest',
        #zca_whitening=True,
        channel_shift_range=0.9,
        #brightness_range=[0.5,1.5]
    )
    datagen.fit(X_train_synth)

    attn_model.compile(
        loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=['acc',
                 keras_metrics.precision(),
                 keras_metrics.recall()])

    attn_model.summary()  # summary() prints the model and returns None

    model_name = 'attn_model_dense_st-sc'

    checkpointer = ModelCheckpoint(filepath='{}.h5'.format(model_name),
                                   monitor='val_acc',
                                   verbose=1,
                                   save_best_only=True)
    #reduce_lr = LearningRateScheduler(lr_scheduler, verbose=1)

    history = attn_model.fit_generator(
        datagen.flow(X_train_synth, y_train_synth, batch_size=batch_size),
        validation_data=(X_val_sx, y_val_sx),
        epochs=train_iters,
        callbacks=[checkpointer]  #, reduce_lr, tensorboard_callback]
    )

    attn_model.load_weights('{}.h5'.format(model_name))
    model_eval(attn_model, X_val_sx, y_val_sx, model_name, 0.5)

    return model
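
The closing `model_eval(attn_model, X_val_sx, y_val_sx, model_name, 0.5)` call is not defined in this snippet. A minimal sketch of a thresholded binary evaluation consistent with that signature; the choice of metrics and output format is an assumption:

from sklearn.metrics import accuracy_score, precision_score, recall_score


def model_eval(model, X, y, model_name, threshold):
    # binarize the sigmoid outputs at `threshold` and report standard metrics
    y_prob = model.predict(X).ravel()
    y_pred = (y_prob >= threshold).astype(int)
    print('{}: acc={:.3f}, precision={:.3f}, recall={:.3f}'.format(
        model_name, accuracy_score(y, y_pred), precision_score(y, y_pred),
        recall_score(y, y_pred)))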
Example #6
    loss_all['DetNetDemod'] = loss_DetNetDemod  # loss values
    train_ed = time.time()
    print("DetNetDemod Train time is: " + str(train_ed - train_st))

# Testing
results = {
    'SNR_dBs':
    np.arange(params['SNR_dB_min_test'], params['SNR_dB_max_test'],
              params['SNR_step_test'])
}
if params['isTest']:
    test_st = time.time()
    if 'DetNetDemod' in params['simulation_algorithms']:
        print('========Testing the %s detection method, now_time is %s=========' %
              ('DetNetDemod', datetime.datetime.now().strftime('%X')))
        ser_DetNetDemods = model_eval(sess, params, DetNetDemod_nodes,
                                      gen_data, params['test_iterations'])
        results['ser_DetNetDemods'] = ser_DetNetDemods
        print(ser_DetNetDemods)

    if 'DetNetPIC' in params['simulation_algorithms']:
        print('========Testing the %s detection method, now_time is %s=========' %
              ('DetNetPIC', datetime.datetime.now().strftime('%X')))
        ser_DetNetPIC1s = model_eval(sess, params, DetNetPIC_nodes1, gen_data,
                                     params['test_iterations'])
        ser_DetNetPIC2s = model_eval(sess, params, DetNetPIC_nodes2, gen_data,
                                     params['test_iterations'])
        results['ser_DetNetPIC1s'] = ser_DetNetPIC1s
        results['ser_DetNetPIC2s'] = ser_DetNetPIC2s
        print(ser_DetNetPIC1s, '\n', ser_DetNetPIC2s)

    if 'DetNetSIC' in params['simulation_algorithms']:
Example #7
model = models.get_model(args.model, n_cls, half_prec,
                         data.shapes_dict[args.dataset], args.n_filters_cnn,
                         args.n_hidden_fc)
model = model.cuda()
model_dict = torch.load('models/{}.pth'.format(args.model_path))
if args.early_stopped_model:
    model.load_state_dict(model_dict['best'])
else:
    model.load_state_dict(model_dict['last'] if 'last' in model_dict else model_dict)

opt = torch.optim.SGD(model.parameters(), lr=0)  # needed for backprop only
if half_prec:
    model, opt = amp.initialize(model, opt, opt_level="O1")
utils.model_eval(model, half_prec)

eps, pgd_alpha, pgd_alpha_rr = eps / 255, pgd_alpha / 255, pgd_alpha_rr / 255

eval_batches_all = data.get_loaders(
    args.dataset,
    -1,
    args.batch_size_eval,
    train_set=(args.set == 'train'),
    shuffle=False,
    data_augm=False)
eval_batches = data.get_loaders(
    args.dataset,
    n_eval,
    args.batch_size_eval,
    train_set=(args.set == 'train'),
Example #8
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Adversarial MNIST Example')
    parser.add_argument('--use-pretrained',
                        action='store_true',
                        default=False,
                        help='load the pretrained model instead of training')
    parser.add_argument('--adversarial-training',
                        action='store_true',
                        default=False,
                        help='also run the adversarial training process')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for testing (default: 512)')
    args = parser.parse_args()

    # Define what device we are using
    use_cuda = torch.cuda.is_available()
    print("CUDA Available: ", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")

    # MNIST train and test dataset/dataloader declarations
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    # Initialize the network
    model = LeNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    if args.use_pretrained:
        print('Loading the pretrained model')
        model.load_state_dict(
            torch.load('resources/lenet_mnist_model.bin', map_location='cpu'))
    else:
        print('Training on the MNIST dataset')
        model_train(model, train_loader, F.nll_loss, optimizer, epochs=10)

    print('Evaluating the neural network')
    # Evaluate the accuracy of the MNIST model on clean examples
    accuracy, _ = model_eval(model, test_loader, F.nll_loss)
    print('Test accuracy on clean examples: ' + str(accuracy))

    # Evaluate the accuracy of the MNIST model on adversarial examples
    accuracy, _ = model_eval(model,
                             test_loader,
                             F.nll_loss,
                             attack_method=fgsm_attack)
    print('Test accuracy on adversarial examples: ' + str(accuracy))

    if args.adversarial_training:
        print("Repeating the process, with adversarial training")
        # Perform adversarial training
        model_train(model,
                    train_loader,
                    F.nll_loss,
                    optimizer,
                    epochs=10,
                    attack_method=fgsm_attack)

        # Evaluate the accuracy of the adversarially trained MNIST model on
        # clean examples
        accuracy, _ = model_eval(model, test_loader, F.nll_loss)
        print('Test accuracy on clean examples: ' + str(accuracy))

        # Evaluate the accuracy of the adversarially trained MNIST model on
        # adversarial examples
        accuracy_adv, _ = model_eval(model,
                                     test_loader,
                                     F.nll_loss,
                                     attack_method=fgsm_attack)
        print('Test accuracy on adversarial examples: ' + str(accuracy_adv))
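
`fgsm_attack` is used throughout this example but never shown. A standard single-step FGSM sketch against the NLL loss used here, with an assumed `(model, x, y)` signature and an illustrative perturbation budget:

def fgsm_attack(model, x, y, eps=0.25):
    # single-step FGSM: move each pixel by eps in the direction that
    # increases the loss, then clip back to the valid MNIST range [0, 1]
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.nll_loss(model(x_adv), y)
    grad = torch.autograd.grad(loss, x_adv)[0]
    x_adv = x_adv.detach() + eps * grad.sign()
    return torch.clamp(x_adv, 0, 1)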