Ejemplo n.º 1
0
def adversarial_accuracy(model,
                         dataset_loader,
                         device,
                         solvers=None,
                         solver_options=None,
                         args=None):
    """Return the model's accuracy on inputs produced by the selected attack.

    Every batch of ``dataset_loader`` is passed through the attack chosen by
    ``args.adv_testing_mode``; the arg-max of the model output on the attacked
    batch is then compared with the labels returned by the attack.

    Args:
        model: Network to evaluate. It is switched to eval mode.
        dataset_loader: Iterable of ``(x, y)`` batches that also exposes a
            ``dataset`` attribute; ``len(dataset_loader.dataset)`` is the
            denominator of the returned accuracy.
        device: Device every batch is moved to.
        solvers: Forwarded to the attack and, when not ``None``, to the model
            call together with ``solver_options``.
        solver_options: Forwarded alongside ``solvers``.
        args: Object with an ``adv_testing_mode`` attribute equal to
            ``"clean"``, ``"fgsm"`` or ``"at"``. Required despite the
            ``None`` default.

    Returns:
        Number of correct predictions divided by
        ``len(dataset_loader.dataset)``.

    Raises:
        ValueError: If ``args`` is ``None`` or the mode is not recognised.
    """
    if args is None:
        raise ValueError(
            "args with an 'adv_testing_mode' attribute is required.")

    model.eval()

    mode = args.adv_testing_mode
    if mode == "clean":
        test_attack = Clean(model)
    elif mode == "fgsm":
        test_attack = FGSM(model, mean=[0.], std=[1.], **CONFIG_FGSM_TRAIN)
    elif mode == "at":
        test_attack = PGD(model, mean=[0.], std=[1.], **CONFIG_PGD_TRAIN)
    else:
        raise ValueError("Attack type not understood.")

    total_correct = 0
    for x, y in dataset_loader:
        x, y = x.to(device), y.to(device)
        # The attack runs outside torch.no_grad(), as in the original code
        # (presumably because gradient-based attacks need autograd).
        x, y = test_attack(x, y, {
            "solvers": solvers,
            "solver_options": solver_options
        })
        # Assumes y holds integer class indices (the previous code one-hot
        # encoded it with a fixed width of 10 and immediately took the
        # arg-max again) -- TODO confirm no caller passes other label shapes.
        # Flattened so that an (N, 1) label tensor cannot broadcast against
        # the (N,) predictions.
        target_class = y.cpu().numpy().reshape(-1)
        with torch.no_grad():
            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
        predicted_class = np.argmax(out.cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)
Ejemplo n.º 2
0
def train(model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          device,
          is_odenet=True,
          args=None):
    """Run one optimisation step on the next batch from ``data_gen``.

    The batch is moved to ``device``, passed through the attack selected by
    ``args.adv_training_mode``, optionally perturbed with Gaussian noise of
    standard deviation ``args.data_noise_std``, and used for a single
    forward/backward pass followed by ``optimizer.step()``.

    Args:
        model: Network to train; switched to train mode.
        data_gen: Iterator yielding ``(x, y)`` batches.
        solvers: Forwarded to the attack and, if ``is_odenet``, to the model.
        solver_options: Forwarded alongside ``solvers``.
        criterion: Callable computing the loss from ``(logits, y)``.
        optimizer: Optimizer whose gradients are zeroed and which is stepped.
        device: Device the batch is moved to.
        is_odenet: When true the model is called with the solvers, the solver
            options and a ``Namespace(ss_loss=...)``; otherwise with ``x`` only.
        args: Object providing ``adv_training_mode``, ``data_noise_std``,
            ``ss_loss`` and (when ``ss_loss`` is truthy) ``ss_loss_reg``.
            Required despite the ``None`` default.

    Returns:
        ``{'xentropy': float}``, plus an ``'ss_loss'`` entry when
        ``args.ss_loss`` is truthy.

    Raises:
        ValueError: If ``args.adv_training_mode`` is not ``"clean"``,
            ``"fgsm"`` or ``"at"``.
    """
    model.train()
    optimizer.zero_grad()

    inputs, targets = next(data_gen)
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Pick the attack used to craft the training inputs.
    mode = args.adv_training_mode
    if mode == "clean":
        attack = Clean(model)
    elif mode == "fgsm":
        attack = FGSM(model, **CONFIG_FGSM_TRAIN)
    elif mode == "at":
        attack = PGD(model, **CONFIG_PGD_TRAIN)
    else:
        raise ValueError("Attack type not understood.")
    attack_kwargs = {"solvers": solvers, "solver_options": solver_options}
    inputs, targets = attack(inputs, targets, attack_kwargs)

    # Optional additive Gaussian input noise, applied outside autograd.
    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            inputs = inputs + args.data_noise_std * torch.randn_like(inputs)

    # Forward pass.
    if is_odenet:
        forward_flags = Namespace(ss_loss=args.ss_loss)
        logits = model(inputs, solvers, solver_options, forward_flags)
    else:
        logits = model(inputs)

    xentropy = criterion(logits, targets)

    # Without the extra regulariser the cross-entropy is the whole loss.
    if not args.ss_loss:
        xentropy.backward()
        optimizer.step()
        return {'xentropy': xentropy.item()}

    ss_loss = model.get_ss_loss()
    total_loss = xentropy + args.ss_loss_reg * ss_loss
    total_loss.backward()
    optimizer.step()
    return {'xentropy': xentropy.item(), 'ss_loss': ss_loss.item()}
Ejemplo n.º 3
0
def run_attack(model, epsilons, attack_modes, loaders, device='cuda'):
    """Evaluate ``model`` under every (attack mode, epsilon) combination.

    For each pair an attack object is built, ``test`` is run on
    ``loaders["val"]`` with it, the returned metrics are printed, and the
    ``'accuracy_adv'`` metric is recorded.

    Args:
        model: Model under evaluation; passed to the attack and to ``test``.
        epsilons: Iterable of ``eps`` values given to the attacks.
        attack_modes: Iterable of mode names: ``"clean"``, ``"fgsm"``,
            ``"at"``, ``"at_ls"``, ``"av"`` or ``"fs"``.
        loaders: Mapping containing a ``"val"`` entry that is handed to
            ``test``.
        device: Device forwarded to ``test``.

    Returns:
        ``defaultdict(list)`` mapping ``'accuracy_<mode>'`` to the list of
        ``'accuracy_adv'`` values, one per epsilon, in iteration order.

    Raises:
        ValueError: If a mode is not one of the names listed above.
            Previously an unknown mode either failed on an unbound variable
            or silently reused the attack from the preceding iteration.
    """
    robust_accuracies = defaultdict(list)

    for mode in attack_modes:
        for epsilon in epsilons:
            test_attack = _build_test_attack(model, mode, epsilon)

            print("Attack {}".format(mode))
            test_metrics = test(loaders["val"],
                                model,
                                test_attack,
                                device,
                                show_progress=True)
            test_log = "Test: | " + " | ".join(
                f"{name}: {value:.6f}"
                for name, value in test_metrics.items())
            print(test_log)

            robust_accuracies['accuracy_{}'.format(mode)].append(
                test_metrics['accuracy_adv'])

    return robust_accuracies


def _build_test_attack(model, mode, epsilon):
    """Return the attack object for ``mode`` configured with ``eps=epsilon``."""
    if mode == "clean":
        return Clean(model)
    if mode == "fgsm":
        return FGSM(model, eps=epsilon)
    if mode in ("at", "at_ls", "av", "fs"):
        # TODO: "at_ls", "av" and "fs" still fall back to PGD; the original
        # code marks this as the wrong attack for those modes.
        # Alternative setting kept from the original for reference:
        #   {"eps": epsilon, "lr": 2.0 / 255 * 10, "n_iter": 20}
        return PGD(model, eps=epsilon, lr=1, n_iter=5)
    raise ValueError(f"Attack type not understood: {mode!r}")
Ejemplo n.º 4
0
def train(model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          device,
          is_odenet=True,
          iter=None,
          args=None):
    """Run one training iteration with gradient accumulation and amp scaling.

    Gradients are accumulated over windows of ``args.zero_grad_every``
    iterations: they are cleared on the first iteration of a window
    (``iter % k == 0``) and the optimizer steps on the last one
    (``(iter + 1) % k == 0``). The previous code cleared them on the same
    iteration that stepped, which discarded everything accumulated earlier
    in the window. With ``zero_grad_every == 1`` the behaviour is unchanged.
    Note that no loss rescaling by ``1 / k`` is applied.

    Args:
        model: Network to train; switched to train mode.
        data_gen: Iterator yielding ``(x, y)`` batches.
        solvers: Sequence of solver objects. When ``args.noise_type`` is not
            ``None`` their ``u``/``v`` attributes are replaced by the output
            of ``noise_params`` for this iteration and restored from
            ``u0``/``v0`` afterwards.
        solver_options: Forwarded to the attack and the model.
        criterion: Callable computing the loss from ``(logits, y)``.
        optimizer: Optimizer used with ``amp.scale_loss``.
        device: Device the batch is moved to.
        is_odenet: When true the model is called with solvers, options and a
            ``Namespace(ss_loss=...)``; otherwise with ``x`` only.
        iter: Zero-based iteration index (integer). Required despite the
            ``None`` default. The name shadows the builtin but is part of the
            public signature.
        args: Object providing ``zero_grad_every``, ``noise_type``,
            ``noise_sigma``, ``noise_prob``, ``adv_training_mode``,
            ``data_noise_std``, ``ss_loss``, ``ss_loss_reg`` and
            ``grad_clipping_threshold``. Required despite the ``None`` default.

    Returns:
        ``{'xentropy': float}``, plus an ``'ss_loss'`` entry when
        ``args.ss_loss`` is truthy.

    Raises:
        ValueError: If ``args.adv_training_mode`` is not one of ``"clean"``,
            ``"fgsm"``, ``"fgsm_random"`` or ``"at"``.
    """
    model.train()

    accumulation_steps = args.zero_grad_every
    # Clear gradients at the START of an accumulation window so that the
    # backward passes of the whole window contribute to the next step.
    if iter % accumulation_steps == 0:
        optimizer.zero_grad()

    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)

    ### Noise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = noise_params(solver.u0,
                                              solver.v0,
                                              std=args.noise_sigma,
                                              bernoulli_p=args.noise_prob,
                                              noise_type=args.noise_type)
            solver.build_ButcherTableau()

    if args.adv_training_mode == "clean":
        train_attack = Clean(model)
    elif args.adv_training_mode == "fgsm":
        train_attack = FGSM(model, **CONFIG_FGSM_TRAIN)
    elif args.adv_training_mode == "fgsm_random":
        train_attack = FGSMRandom(model, **CONFIG_FGSMRandom_TRAIN)
    elif args.adv_training_mode == "at":
        train_attack = PGD(model, **CONFIG_PGD_TRAIN)
    else:
        raise ValueError("Attack type not understood.")
    x, y = train_attack(x, y, {"solvers": solvers, "solver_options": solver_options})

    ### Add noise:
    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            x = x + args.data_noise_std * torch.randn_like(x)

    ### Forward pass
    if is_odenet:
        logits = model(x, solvers, solver_options, Namespace(ss_loss=args.ss_loss))
    else:
        logits = model(x)

    xentropy = criterion(logits, y)
    if args.ss_loss:
        ss_loss = model.get_ss_loss()
        loss = xentropy + args.ss_loss_reg * ss_loss
    else:
        loss = xentropy

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    if (iter + 1) % accumulation_steps == 0:
        # Clip once, on the fully accumulated gradient, right before the step.
        if args.grad_clipping_threshold:
            nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                     args.grad_clipping_threshold)
        optimizer.step()

    ### Denoise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = solver.u0, solver.v0
            solver.build_ButcherTableau()

    if args.ss_loss:
        return {'xentropy': xentropy.item(), 'ss_loss': ss_loss.item()}
    return {'xentropy': xentropy.item()}
Ejemplo n.º 5
0
def train(itr,
          model,
          data_gen,
          solvers,
          solver_options,
          criterion,
          optimizer,
          batch_time_meter,
          f_nfe_meter,
          b_nfe_meter,
          device='cpu',
          dtype=torch.float32,
          is_odenet=True,
          args=None,
          logger=None,
          wandb_logger=None):
    """Run one training step and update the timing / NFE meters.

    The next batch is moved to ``device``, passed through the attack selected
    by ``args.adv_training_mode``, optionally perturbed with Gaussian noise,
    and used for one forward/backward pass followed by ``optimizer.step()``.
    For ODE models the ``rhs_func.nfe`` counters of ``model.blocks`` are read
    and reset once before ``backward`` (reported as forward NFE) and once
    after the step (reported as backward NFE).

    Args:
        itr: Iteration index; not used inside this function.
        model: Network to train. Its train/eval mode is not changed here.
        data_gen: Iterator yielding ``(x, y)`` batches.
        solvers: Sequence of solver objects. When ``args.noise_type`` is not
            ``None`` their ``u``/``v`` attributes are replaced by the output
            of ``noise_params`` for this step and restored from ``u0``/``v0``
            afterwards.
        solver_options: Forwarded to the attack and the model.
        criterion: Callable computing the loss from ``(logits, y)``.
        optimizer: Optimizer; zeroed at the start and stepped once.
        batch_time_meter: Meter updated with the wall-clock duration.
        f_nfe_meter: Meter updated with the forward NFE (ODE models only).
        b_nfe_meter: Meter updated with the backward NFE (ODE models only).
        device: Device the batch is moved to.
        dtype: Not used inside this function.
        is_odenet: Selects the ODE call signature and the NFE bookkeeping.
        args: Object providing ``noise_type``, ``noise_sigma``,
            ``noise_prob``, ``adv_training_mode``, ``data_noise_std``,
            ``ss_loss`` and ``ss_loss_reg``. Required despite the ``None``
            default.
        logger: Not used inside this function.
        wandb_logger: Optional object with a ``log(dict)`` method that
            receives the scalar losses of this step.

    Returns:
        None.

    Raises:
        ValueError: If ``args.adv_training_mode`` is not ``"clean"``,
            ``"fgsm"`` or ``"at"``.
    """
    end = time.time()

    optimizer.zero_grad()
    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)

    ##### Noise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = noise_params(solver.u0,
                                              solver.v0,
                                              std=args.noise_sigma,
                                              bernoulli_p=args.noise_prob,
                                              noise_type=args.noise_type)
            solver.build_ButcherTableau()

    if args.adv_training_mode == "clean":
        train_attack = Clean(model)
    elif args.adv_training_mode == "fgsm":
        train_attack = FGSM(model, **CONFIG_FGSM_TRAIN)
    elif args.adv_training_mode == "at":
        train_attack = PGD(model, **CONFIG_PGD_TRAIN)
    else:
        raise ValueError("Attack type not understood.")
    x, y = train_attack(x, y, {"solvers": solvers, "solver_options": solver_options})

    # Add noise:
    if args.data_noise_std > 1e-12:
        with torch.no_grad():
            x = x + args.data_noise_std * torch.randn_like(x)

    ##### Forward pass
    if is_odenet:
        logits = model(x, solvers, solver_options, Namespace(ss_loss=args.ss_loss))
    else:
        logits = model(x)

    xentropy = criterion(logits, y)
    if args.ss_loss:
        ss_loss = model.get_ss_loss()
        loss = xentropy + args.ss_loss_reg * ss_loss
    else:
        ss_loss = 0.
        loss = xentropy

    if wandb_logger is not None:
        # Log plain Python floats for every entry; ss_loss used to be passed
        # as a live tensor in one branch and as the float 0. in the other.
        wandb_logger.log({"xentropy": xentropy.item(),
                          "ss_loss": float(ss_loss),
                          "loss": loss.item(),
                          "log_func": "train"})
    # TODO: `logger` is accepted but nothing is written to it yet.

    ##### Compute NFE-forward
    if is_odenet:
        nfe_forward = _pop_nfe(model)

    loss.backward()
    optimizer.step()

    ##### Compute NFE-backward
    if is_odenet:
        nfe_backward = _pop_nfe(model)

    ##### Denoise params
    if args.noise_type is not None:
        for solver in solvers:
            solver.u, solver.v = solver.u0, solver.v0
            solver.build_ButcherTableau()

    batch_time_meter.update(time.time() - end)
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)


def _pop_nfe(model):
    """Return the sum of ``rhs_func.nfe`` over ``model.blocks`` and reset each to 0."""
    total = 0
    for block in model.blocks:
        total += block.rhs_func.nfe
        block.rhs_func.nfe = 0
    return total