Example #1
def _main():
    if config.model_train:
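        # if no validation set exists on disk, carve one out of the training data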
        if not os.path.exists(config.val_path):
            x, y = load_data()
            print(len(x), len(y))
            x_train, x_val, y_train, y_val = train_test_split(
                x, y, test_size=config.ratio)
        else:
            x_train, y_train = load_data()
            x_val, y_val = load_data(val=True)
        model_train(x_train, y_train, x_val, y_val)
    else:
        x, y = load_data()
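
The model_train entry point is project-specific and not shown above. A minimal sketch of the kind of helper this snippet assumes (build_model and the fit-style API are illustrative, not from the source):

def model_train(x_train, y_train, x_val, y_val):
    # sketch only: fit any estimator exposing a fit() API and validate on the held-out split
    model = build_model()  # hypothetical model factory
    model.fit(x_train, y_train, validation_data=(x_val, y_val))
    return model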
Example #2
loss_MMNet = []
print('===========MMNet is Training, now_time is %s=========' %
      (datetime.datetime.now().strftime('%X')))
train_st = time.time()
for n in range(params['nRounds']):
    for epoch in range(params['maxEpoch']):
        # x_Feed, H_Feed, y_Feed, noise_sigma2_Feed = gen_data.dataTrain(epoch, params['SNR_dB_train'])
        x_Feed, H_Feed, y_Feed, noise_sigma2_Feed = gen_data.dataTrain(
            0, params['SNR_dB_train'])  # fixed channel: reuse batch 0 every epoch
        data_Feed = {
            'x': x_Feed,
            'y': y_Feed,
            'H': H_Feed,
            'noise_sigma2': noise_sigma2_Feed,
        }
        model_train(sess, MMNet_nodes, data_Feed)
        if epoch % 100 == 0:
            print('noise_sigma2:', noise_sigma2_Feed)
            print('estimate:', model_est(sess, MMNet_nodes, data_Feed))
            print('===========epoch%d,now_time is %s=========' %
                  (epoch + n * params['maxEpoch'],
                   datetime.datetime.now().strftime('%X')))
            ser_MMNet, loss = model_loss(sess, MMNet_nodes, data_Feed)
            print('ser_MMNet', ser_MMNet, 'loss=', loss, '\n')
            loss_MMNet.append(loss)
loss_all['MMNet'] = loss_MMNet
train_ed = time.time()
print("MMNet Train time is: " + str(train_ed - train_st))

# Testing
results = {
Example #3
def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    cur_timestamp = str(datetime.now())[:-3]  # keep milliseconds to reduce the chance of a name collision
    model_width = {'linear': '', 'cnn': args.n_filters_cnn, 'lenet': '', 'resnet18': ''}[args.model]
    model_str = '{}{}'.format(args.model, model_width)
    model_name = '{} dataset={} model={} eps={} attack={} m={} attack_init={} fgsm_alpha={} epochs={} pgd={}-{} grad_align_cos_lambda={} lr_max={} seed={}'.format(
        cur_timestamp, args.dataset, model_str, args.eps, args.attack, args.minibatch_replay, args.attack_init, args.fgsm_alpha, args.epochs,
        args.pgd_alpha_train, args.pgd_train_n_iters, args.grad_align_cos_lambda, args.lr_max, args.seed)
    if not os.path.exists('models'):
        os.makedirs('models')
    logger = utils.configure_logger(model_name, args.debug)
    logger.info(args)
    half_prec = args.half_prec
    n_cls = 2 if 'binary' in args.dataset else 10

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    double_bp = args.grad_align_cos_lambda > 0
    n_eval_every_k_iter = args.n_eval_every_k_iter
    args.pgd_alpha = args.eps / 4

    eps, pgd_alpha, pgd_alpha_train = args.eps / 255, args.pgd_alpha / 255, args.pgd_alpha_train / 255
    train_data_augm = args.dataset not in ['mnist']
    train_batches = data.get_loaders(args.dataset, -1, args.batch_size, train_set=True, shuffle=True, data_augm=train_data_augm)
    train_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size, train_set=True, shuffle=False, data_augm=False)
    test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)
    test_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)

    model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn).cuda()
    model.apply(utils.initialize_weights)
    model.train()

    if args.model == 'resnet18':
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)
    elif args.model == 'cnn':
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    elif args.model == 'lenet':
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    else:
        raise ValueError('decide about the right optimizer for the new model')

    if half_prec:
        if double_bp:
            amp.register_float_function(torch, 'batch_norm')
        model, opt = amp.initialize(model, opt, opt_level="O1")

    if args.attack == 'fgsm':  # needed here only for Free-AT
        delta = torch.zeros(args.batch_size, *data.shapes_dict[args.dataset][1:]).cuda()
        delta.requires_grad = True

    lr_schedule = utils.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_max)
    loss_function = nn.CrossEntropyLoss()

    train_acc_pgd_best, best_state_dict = 0.0, copy.deepcopy(model.state_dict())
    start_time = time.time()
    time_train, iteration, best_iteration = 0, 0, 0
    for epoch in range(args.epochs + 1):
        train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0
        for i, (X, y) in enumerate(train_batches):
            if i % args.minibatch_replay != 0 and i > 0:  # take new inputs only each `minibatch_replay` iterations
                X, y = X_prev, y_prev
            time_start_iter = time.time()
            # epoch=0 runs only for one iteration (to check the training stats at init)
            if epoch == 0 and i > 0:
                break
            X, y = X.cuda(), y.cuda()
            lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches))  # epoch - 1 since the 0th epoch is skipped
            opt.param_groups[0].update(lr=lr)

            if args.attack in ['pgd', 'pgd_corner']:
                pgd_rs = args.attack_init == 'random'
                n_eps_warmup_epochs = 5
                n_iterations_max_eps = n_eps_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                eps_pgd_train = min(iteration / n_iterations_max_eps * eps, eps) if args.dataset == 'svhn' else eps
                delta = utils.attack_pgd_training(
                    model, X, y, eps_pgd_train, pgd_alpha_train, opt, half_prec, args.pgd_train_n_iters, rs=pgd_rs)
                if args.attack == 'pgd_corner':
                    delta = eps * utils.sign(delta)  # project to the corners
                    delta = clamp(X + delta, 0, 1) - X

            elif args.attack == 'fgsm':
                if args.minibatch_replay == 1:
                    if args.attack_init == 'zero':
                        delta = torch.zeros_like(X, requires_grad=True)
                    elif args.attack_init == 'random':
                        delta = utils.get_uniform_delta(X.shape, eps, requires_grad=True)
                    else:
                        raise ValueError('wrong args.attack_init')
                else:  # if Free-AT, we just reuse the existing delta from the previous iteration
                    delta.requires_grad = True

                X_adv = clamp(X + delta, 0, 1)
                output = model(X_adv)
                loss = F.cross_entropy(output, y)
                if half_prec:
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        grad = torch.autograd.grad(scaled_loss, delta, create_graph=double_bp)[0]
                        grad /= scaled_loss / loss  # undo the loss scaling
                else:
                    grad = torch.autograd.grad(loss, delta, create_graph=double_bp)[0]

                grad = grad.detach()

                argmax_delta = eps * utils.sign(grad)

                n_alpha_warmup_epochs = 5
                n_iterations_max_alpha = n_alpha_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                fgsm_alpha = min(iteration / n_iterations_max_alpha * args.fgsm_alpha, args.fgsm_alpha) if args.dataset == 'svhn' else args.fgsm_alpha
                delta.data = clamp(delta.data + fgsm_alpha * argmax_delta, -eps, eps)
                delta.data = clamp(X + delta.data, 0, 1) - X

            elif args.attack == 'random_corner':
                delta = utils.get_uniform_delta(X.shape, eps, requires_grad=False)
                delta = eps * utils.sign(delta)

            elif args.attack == 'none':
                delta = torch.zeros_like(X, requires_grad=False)
            else:
                raise ValueError('wrong args.attack')

            # extra FP+BP to calculate the gradient to monitor it
            if args.attack in ['none', 'random_corner', 'pgd', 'pgd_corner']:
                grad = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none',
                                      backprop=args.grad_align_cos_lambda != 0.0)

            delta = delta.detach()

            output = model(X + delta)
            loss = loss_function(output, y)

            reg = torch.zeros(1).cuda()[0]  # for .item() to run correctly
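            # gradient alignment (GradAlign) regularizer: encourage the input
            # gradient at x to align with the gradient at a randomly perturbed point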
            if args.grad_align_cos_lambda != 0.0:
                grad2 = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='random_uniform', backprop=True)
                grads_nnz_idx = ((grad**2).sum([1, 2, 3])**0.5 != 0) * ((grad2**2).sum([1, 2, 3])**0.5 != 0)
                grad1, grad2 = grad[grads_nnz_idx], grad2[grads_nnz_idx]
                grad1_norms, grad2_norms = l2_norm_batch(grad1), l2_norm_batch(grad2)
                grad1_normalized = grad1 / grad1_norms[:, None, None, None]
                grad2_normalized = grad2 / grad2_norms[:, None, None, None]
                cos = torch.sum(grad1_normalized * grad2_normalized, (1, 2, 3))
                reg += args.grad_align_cos_lambda * (1.0 - cos.mean())

            loss += reg

            if epoch != 0:
                opt.zero_grad()
                utils.backward(loss, opt, half_prec)
                opt.step()

            time_train += time.time() - time_start_iter
            train_loss += loss.item() * y.size(0)
            train_reg += reg.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

            with torch.no_grad():  # no grad for the stats
                grad_norm_x += l2_norm_batch(grad).sum().item()
                delta_final = clamp(X + delta, 0, 1) - X  # we should measure delta after the projection onto [0, 1]^d
                avg_delta_l2 += ((delta_final ** 2).sum([1, 2, 3]) ** 0.5).sum().item()

            if iteration % args.eval_iter_freq == 0:
                train_loss, train_reg = train_loss / train_n, train_reg / train_n
                train_acc, avg_delta_l2 = train_acc / train_n, avg_delta_l2 / train_n

                # it'd be incorrect to recalculate the BN stats on the test sets and for clean / adversarial points
                utils.model_eval(model, half_prec)

                test_acc_clean, _, _ = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_fgsm, test_loss_fgsm, fgsm_deltas = rob_acc(test_batches_fast, model, eps, eps, opt, half_prec, 1, 1, rs=False)
                test_acc_pgd, test_loss_pgd, pgd_deltas = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)
                cos_fgsm_pgd = utils.avg_cos_np(fgsm_deltas, pgd_deltas)
                train_acc_pgd, _, _ = rob_acc(train_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)  # needed for early stopping

                grad_x = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=False)
                grad_eta = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=True)
                cos_x_eta = utils.avg_cos_np(grad_x, grad_eta)

                time_elapsed = time.time() - start_time
                train_str = '[train] loss {:.3f}, reg {:.3f}, acc {:.2%}, acc_pgd {:.2%}'.format(train_loss, train_reg, train_acc, train_acc_pgd)
                test_str = '[test] acc_clean {:.2%}, acc_fgsm {:.2%}, acc_pgd {:.2%}, cos_x_eta {:.3}, cos_fgsm_pgd {:.3}'.format(
                    test_acc_clean, test_acc_fgsm, test_acc_pgd, cos_x_eta, cos_fgsm_pgd)
                logger.info('{}-{}: {}  {} ({:.2f}m, {:.2f}m)'.format(epoch, iteration, train_str, test_str,
                                                                      time_train/60, time_elapsed/60))

                if train_acc_pgd > train_acc_pgd_best:  # catastrophic overfitting can be detected on the training set
                    best_state_dict = copy.deepcopy(model.state_dict())
                    train_acc_pgd_best, best_iteration = train_acc_pgd, iteration

                utils.model_train(model, half_prec)
                train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0

            iteration += 1
            X_prev, y_prev = X.clone(), y.clone()  # needed for Free-AT

        if epoch == args.epochs:
            torch.save({'last': model.state_dict(), 'best': best_state_dict}, 'models/{} epoch={}.pth'.format(model_name, epoch))
            # disable global conversion to fp16 from amp.initialize() (https://github.com/NVIDIA/apex/issues/567)
            context_manager = amp.disable_casts() if half_prec else utils.nullcontext()
            with context_manager:
                last_state_dict = copy.deepcopy(model.state_dict())
                half_prec = False  # final eval is always in fp32
                model.load_state_dict(last_state_dict)
                utils.model_eval(model, half_prec)
                opt = torch.optim.SGD(model.parameters(), lr=0)

                attack_iters, n_restarts = (50, 10) if not args.debug else (10, 3)
                test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                logger.info('[last: test on 10k points] acc_clean {:.2%}, pgd_rr {:.2%}'.format(test_acc_clean, test_acc_pgd_rr))

                if args.eval_early_stopped_model:
                    model.load_state_dict(best_state_dict)
                    utils.model_eval(model, half_prec)
                    test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                    test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                    logger.info('[best: test on 10k points][iter={}] acc_clean {:.2%}, pgd_rr {:.2%}'.format(
                        best_iteration, test_acc_clean, test_acc_pgd_rr))

        utils.model_train(model, half_prec)

    logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60))
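
The helpers get_input_grad and l2_norm_batch come from the repository's utils and are not shown above; a minimal sketch consistent with how they are called here, ignoring the amp/half-precision details (an assumption, not the repository's exact code):

import torch
import torch.nn.functional as F

def l2_norm_batch(v):
    # per-example L2 norm over all non-batch dimensions
    return (v ** 2).sum([1, 2, 3]) ** 0.5

def get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none', backprop=False):
    # gradient of the loss w.r.t. the input, optionally taken at a uniformly
    # perturbed point (opt/half_prec would only matter for amp loss scaling)
    if delta_init == 'none':
        delta = torch.zeros_like(X, requires_grad=True)
    elif delta_init == 'random_uniform':
        delta = torch.zeros_like(X).uniform_(-eps, eps).requires_grad_(True)
    else:
        raise ValueError('wrong delta_init')
    loss = F.cross_entropy(model(X + delta), y)
    grad = torch.autograd.grad(loss, delta, create_graph=backprop)[0]
    return grad if backprop else grad.detach()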
Example #4
loss_all = []
print('=========== is Training, now_time is %s=========' %
      (datetime.datetime.now().strftime('%X')))
train_st = time.time()
for n in range(params['nRounds']):
    for epoch in range(params['maxEpoch']):
        x_Feed, H_Feed, y_Feed, noise_sigma2_Feed = gen_data.dataTrain(
            epoch, params['SNR_dB_train'])
        data_Feed = {
            'x': x_Feed,
            'y': y_Feed,
            'H': H_Feed,
            'noise_sigma2': noise_sigma2_Feed,
        }
        model_train(sess, DetNetSIC2_nodes, data_Feed)
        if epoch % 100 == 0:
            print('===========epoch%d,now_time is %s=========' %
                  (epoch + n * params['maxEpoch'],
                   datetime.datetime.now().strftime('%X')))
            ser, loss = model_loss(sess, DetNetSIC2_nodes, data_Feed)
            print('ser', ser, 'loss=', loss, '\n')
            loss_all.append(loss)
train_ed = time.time()
print("Train time is: " + str(train_ed - train_st))

# Testing
results = {
    'SNR_dBs':
    np.arange(params['SNR_dB_min_test'], params['SNR_dB_max_test'],
              params['SNR_step_test'])
}
Example #5
def main():
    parser = argparse.ArgumentParser(
        description='Test',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('data', metavar='DATA', help='path to file')

    parser.add_argument('--tb-save-path',
                        dest='tb_save_path',
                        metavar='PATH',
                        default='../checkpoints/',
                        help='tensorboard checkpoints path')

    parser.add_argument('--weight-save-path',
                        dest='weight_save_path',
                        metavar='PATH',
                        default='../weights/',
                        help='weight checkpoints path')

    parser.add_argument('--pretrained-weight',
                        dest='weight',
                        metavar='PATH',
                        default=None,
                        help='pretrained weight')

    parser.add_argument('--activation',
                        dest='activation',
                        metavar='activation',
                        default='relu',
                        help='activation of network; \'relu\' or \'sin\'')

    parser.add_argument('--batchsize',
                        dest='batchsize',
                        type=int,
                        metavar='BATCHSIZE',
                        default=1,
                        help='batch size')
    parser.add_argument('--epoch',
                        dest='epoch',
                        type=int,
                        metavar='EPOCH',
                        default=100,
                        help='epochs')

    parser.add_argument(
        '--abs',
        dest='abs',
        action='store_true',
        default=False,
        help='whether we should use ABS when evaluating normal loss')

    parser.add_argument('--epsilon',
                        dest='epsilon',
                        type=float,
                        metavar='EPSILON',
                        default=0.1,
                        help='epsilon')
    parser.add_argument('--lambda',
                        dest='lamb',
                        type=float,
                        metavar='LAMBDA',
                        default=0.005,
                        help='weight balancing the s loss against the normal loss')

    parser.add_argument('--outfile',
                        dest='outfile',
                        metavar='OUTFILE',
                        help='output file')

    args = parser.parse_args()

    writer = SummaryWriter(args.tb_save_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create models
    model = Siren(in_features=3,
                  out_features=1,
                  hidden_features=256,
                  hidden_layers=5,
                  outermost_linear=True).to(device)

    if args.weight is not None:
        try:
            model.load_state_dict(torch.load(args.weight))
        except (OSError, RuntimeError):
            print("Couldn't load pretrained weight: " + args.weight)

    # load
    ds = ObjDataset(args.data)
    samples_n = 20000

    # build the augmentation pipeline, then apply it once to the dataset
    augments = nn.Sequential(ObjUniformSample(samples_n),
                             NormalPerturb(args.epsilon),
                             RandomAugment(samples_n, args.epsilon * 0.5))

    ds = augments(ds)

    p_aug = ds['p'].detach_().to(device)
    n_aug = ds['n'].detach_().to(device)
    s_aug = ds['s'].detach_().to(device)

    p = p_aug[:samples_n]
    n = n_aug[:samples_n]

    p_gt = p.repeat(2, 1)

    writer.add_mesh("1. n_gt",
                    p.unsqueeze(0),
                    colors=(n.unsqueeze(0) * 128 + 128).int())

    optimizer = optim.Adam(list(model.parameters()), lr=1e-4)

    for epoch in range(args.epoch):
        optimizer.zero_grad()

        # train
        utils.model_train(model)
        loss_t, s, n = train(device,
                             model,
                             p_aug,
                             s_aug,
                             n_aug,
                             backward=True,
                             lamb=args.lamb,
                             use_abs=args.abs)

        #loss_x = 1e2 * torch.sum(torch.pow(p_aug - p_gt, 2))
        #loss_x.backward()

        #writer.add_scalars("loss", {'train': loss_t + loss_x.detach()}, epoch)

        # visualization
        with torch.no_grad():

            n_normalized = n / torch.norm(n, dim=1, keepdim=True)

            n_error = torch.sum(n_normalized * n_aug, dim=1,
                                keepdim=True) / torch.norm(
                                    n_aug, dim=1, keepdim=True)

            n_error_originals = n_error[:p.shape[0]]

            writer.add_scalars(
                "cosine similarity",
                {'train': n_error_originals[~torch.isnan(n_error_originals)].detach().mean()},
                epoch)

            if epoch % 10 == 0:
                print(epoch)
                writer.add_mesh(
                    "2. n",
                    p_aug[:p.shape[0]].unsqueeze(0).detach().clone(),
                    colors=(n_normalized[:p.shape[0]].unsqueeze(
                        0).detach().clone() * 128 + 128).int(),
                    global_step=epoch)

                writer.add_mesh(
                    "3. cosine similarity",
                    p_aug[:p.shape[0]].unsqueeze(0).detach().clone(),
                    colors=(F.pad(1 - n_error[:p.shape[0]],
                                  (0, 2)).unsqueeze(0).detach().clone() *
                            256).int(),
                    global_step=epoch)

        # update
        optimizer.step()

        torch.save(model.state_dict(),
                   args.weight_save_path + 'model_%03d.pth' % epoch)

    writer.close()
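
The NaN-safe cosine-similarity bookkeeping above is the core of the evaluation; refactored as a standalone helper it amounts to the following (a sketch, not the source's own code):

import torch

def normal_cosine_similarity(n_pred, n_gt):
    # mean cosine between predicted and ground-truth normals, skipping NaNs
    n_pred = n_pred / torch.norm(n_pred, dim=1, keepdim=True)
    cos = torch.sum(n_pred * n_gt, dim=1, keepdim=True) / torch.norm(n_gt, dim=1, keepdim=True)
    return cos[~torch.isnan(cos)].mean()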
Example #6
if 'DetNet' in params['simulation_algorithms']:
    loss_DetNet = []
    print('===========DetNet is Training, now_time is %s=========' %
          (datetime.datetime.now().strftime('%X')))
    train_st = time.time()
    for n in range(params['nRounds']):
        for epoch in range(params['maxEpoch']):
            x_Feed, H_Feed, y_Feed, noise_sigma2_Feed = gen_data.dataTrain(
                epoch, params['SNR_dB_train'])
            data_Feed = {
                'x': x_Feed,
                'y': y_Feed,
                'H': H_Feed,
                'noise_sigma2': noise_sigma2_Feed,
            }
            model_train(sess, DetNet_nodes, data_Feed)
            if epoch % 1000 == 0:
                print('===========epoch%d,now_time is %s=========' %
                      (epoch + n * params['maxEpoch'],
                       datetime.datetime.now().strftime('%X')))
                ser_DetNet, loss = model_loss(sess, DetNet_nodes, data_Feed)
                print('ser_DetNet', ser_DetNet, 'loss=', loss, '\n')
                loss_DetNet.append(loss)
    loss_all['DetNet'] = loss_DetNet  # loss values
    train_ed = time.time()
    print("DetNet Train time is: " + str(train_ed - train_st))

if 'OAMPNet' in params['simulation_algorithms']:
    loss_OAMPNet = []
    print('===========OAMPNet is Training, now_time is %s=========' %
          (datetime.datetime.now().strftime('%X')))
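
These TensorFlow examples all read from a shared params dict; a minimal illustration of the keys they rely on, with placeholder values that are not taken from the source:

params = {
    'nRounds': 1,                 # outer training rounds
    'maxEpoch': 10000,            # epochs per round
    'SNR_dB_train': 10,           # training SNR in dB
    'SNR_dB_min_test': 0,         # test-time SNR sweep
    'SNR_dB_max_test': 20,
    'SNR_step_test': 2,
    'simulation_algorithms': ['DetNet', 'OAMPNet', 'MMNet'],
}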
Example #7
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Adversarial MNIST Example')
    parser.add_argument('--use-pretrained',
                        action='store_true',
                        default=False,
                        help='uses the pretrained model')
    parser.add_argument('--adversarial-training',
                        action='store_true',
                        default=False,
                        help='takes the adversarial training process')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for testing (default: 512)')
    args = parser.parse_args()

    # Define what device we are using
    use_cuda = torch.cuda.is_available()
    print("CUDA Available: ", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")

    # MNIST train/test datasets and dataloaders
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    # Initialize the network
    model = LeNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    if args.use_pretrained:
        print('Loading the pretrained model')
        model.load_state_dict(
            torch.load('resources/lenet_mnist_model.bin', map_location='cpu'))
    else:
        print('Training on the MNIST dataset')
        model_train(model, train_loader, F.nll_loss, optimizer, epochs=10)

    print('Evaluating the neural network')
    # Evaluate the accuracy of the MNIST model on clean examples
    accuracy, _ = model_eval(model, test_loader, F.nll_loss)
    print('Test accuracy on clean examples: ' + str(accuracy))

    # Evaluate the accuracy of the MNIST model on adversarial examples
    accuracy, _ = model_eval(model,
                             test_loader,
                             F.nll_loss,
                             attack_method=fgsm_attack)
    print('Test accuracy on adversarial examples: ' + str(accuracy))

    if args.adversarial_training:
        print("Repeating the process, with adversarial training")
        # Perform adversarial training
        model_train(model,
                    train_loader,
                    F.nll_loss,
                    optimizer,
                    epochs=10,
                    attack_method=fgsm_attack)

        # Evaluate the accuracy of the adversarially trained MNIST model on
        # clean examples
        accuracy, _ = model_eval(model, test_loader, F.nll_loss)
        print('Test accuracy on clean examples: ' + str(accuracy))

        # Evaluate the accuracy of the adversarially trained MNIST model on
        # adversarial examples
        accuracy_adv, _ = model_eval(model,
                                     test_loader,
                                     F.nll_loss,
                                     attack_method=fgsm_attack)
        print('Test accuracy on adversarial examples: ' + str(accuracy_adv))
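
fgsm_attack, model_train and model_eval are defined elsewhere in this project; a minimal FGSM sketch consistent with how an attack_method is used here (the signature is an assumption):

import torch

def fgsm_attack(model, loss_fn, x, y, epsilon=0.25):
    # single-step FGSM: move the input along the sign of the input gradient
    x_adv = x.clone().detach().requires_grad_(True)
    loss = loss_fn(model(x_adv), y)
    loss.backward()
    return (x_adv + epsilon * x_adv.grad.sign()).clamp(0, 1).detach()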
Example #8
sess = tf.compat.v1.Session(config=config)
sess.run(init)

# Training
print('=========== is Training, now_time is %s=========' % (datetime.datetime.now().strftime('%X')))
train_st = time.time()
for n in range(params['nRounds']):
    for epoch in range(params['maxEpoch']):
        x_Feed, H_Feed, y_Feed, noise_sigma2_Feed = gen_data.dataTrain(epoch, params['SNR_dB_train'])
        data_Feed = {
            'x': x_Feed,
            'y': y_Feed,
            'H': H_Feed,
            'noise_sigma2': noise_sigma2_Feed,
        }
        model_train(sess, DetNetSIC3_node1, data_Feed)
        if epoch % 100 == 0:
            print('===========epoch%d,now_time is %s=========' % (epoch + n*params['maxEpoch'],
                                                                  datetime.datetime.now().strftime('%X')))
            ser, loss = model_loss(sess, DetNetSIC3_node1, data_Feed)
            print('ser', ser, 'loss=', loss, '\n')
train_ed = time.time()
print("Train time is: "+str(train_ed-train_st))

train_st = time.time()
for n in range(params['nRounds']):
    for epoch in range(params['maxEpoch']):
        x_Feed, H_Feed, y_Feed, noise_sigma2_Feed = gen_data.dataTrain(epoch, params['SNR_dB_train'])
        data_Feed = {
            'x': x_Feed,
            'y': y_Feed,