Code example #1
    def __init__(self,
                 ds_name,
                 ds_path,
                 lr,
                 iterations,
                 batch_size,
                 print_freq,
                 k,
                 eps,
                 is_normalized,
                 adv_momentum,
                 store_adv=None,
                 load_adv_dir=None,
                 load_adv_name=None,
                 load_dir=None,
                 load_name=None,
                 save_dir=None):

        self.data_processor = Preprocessor(ds_name, ds_path, is_normalized)

        # Load Data
        self.train_data, self.test_data, self.N_train, self.N_test = \
            self.data_processor.datasets()
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=batch_size,
                                       shuffle=True)
        self.test_loader = DataLoader(self.test_data, batch_size=batch_size)

        # Other Variables
        self.save_dir = save_dir
        self.store_adv = store_adv

        # Set Model Hyperparameters
        self.learning_rate = lr
        self.iterations = iterations
        self.print_freq = print_freq
        self.cuda = torch.cuda.is_available()

        # Load Model to Conduct Adversarial Training
        adversarial_model = self.load_model(self.cuda, load_adv_dir,
                                            load_adv_name, TEST)
        self.adversarial_generator = Attacks(adversarial_model, eps,
                                             self.N_train, self.N_test,
                                             self.data_processor.get_const(),
                                             adv_momentum, is_normalized,
                                             store_adv)

        # Load Target Model
        self.target_model = self.load_model(self.cuda, load_dir, load_name,
                                            TEST)

        # Load Denoiser
        self.denoiser = Denoiser(x_h=32, x_w=32)
        self.denoiser = self.denoiser.cuda()
Code example #2
def handle_attacks(credentials):
    """
    awspx attacks
    """

    include_conditional_attacks = False
    skip_attacks = []
    only_attacks = []

    max_attack_iterations = 5
    max_attack_depth = None

    # Build the attack definitions, skipping conditional actions unless
    # they were explicitly included above.
    attacks = Attacks(
        skip_conditional_actions=not include_conditional_attacks,
        skip_attacks=skip_attacks,
        only_attacks=only_attacks,
        credentials=credentials)

    # Pass an empty string when no depth limit is set.
    attacks.compute(
        max_iterations=max_attack_iterations,
        max_search_depth="" if max_attack_depth is None
        else str(max_attack_depth))
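A minimal driver for the snippet above, assuming the function is imported alongside the rest of the awspx tooling; load_credentials and the profile name are hypothetical stand-ins, not part of awspx's API.

# Hypothetical usage sketch: load_credentials() is an assumed helper,
# not an awspx API.
def main():
    credentials = load_credentials(profile="default")  # assumption
    handle_attacks(credentials)


if __name__ == "__main__":
    main()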
Code example #3
    def __init__(self,
                 ds_name,
                 ds_path,
                 lr,
                 iterations,
                 batch_size,
                 print_freq,
                 k,
                 eps,
                 adv_momentum,
                 train_transform_fn,
                 test_transform_fn,
                 is_normalized,
                 store_adv=False,
                 load_dir=None,
                 load_name=None,
                 load_adv_dir=None,
                 load_adv_name=None,
                 save_dir=None,
                 attack=MODE_PLAIN,
                 train_mode=RAW,
                 test_mode=RAW,
                 mode=TRAIN_AND_TEST):

        # Load Data
        if ds_name == 'CIFAR10':
            self.train_data = torchvision.datasets.CIFAR10(
                ds_path,
                train=True,
                transform=train_transform_fn(),
                download=True)
            self.test_data = torchvision.datasets.CIFAR10(
                ds_path,
                train=False,
                transform=test_transform_fn(),
                download=True)
        else:
            raise ValueError('Unsupported dataset: ' + ds_name)

        # Data loaders
        self.train_loader = torch.utils.data.DataLoader(self.train_data,
                                                        batch_size=batch_size,
                                                        shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(self.test_data,
                                                       batch_size=batch_size)

        # Other Variables
        self.save_dir = save_dir
        self.store_adv = store_adv
        self.train_raw = (train_mode == RAW or train_mode == BOTH)
        self.train_adv = (train_mode == ADV or train_mode == BOTH)
        self.test_raw = (test_mode == RAW or test_mode == BOTH)
        self.test_adv = (test_mode == ADV or test_mode == BOTH)

        # Set Model Hyperparameters
        self.learning_rate = lr
        self.iterations = iterations
        self.print_freq = print_freq

        self.cuda = torch.cuda.is_available()

        # Load model for training
        self.model = self.load_model(self.cuda, load_dir, load_name, mode)

        # Define attack method to generate adversaries.
        if self.train_adv:

            # Load pre-trained model
            adversarial_model = self.load_model(self.cuda, load_adv_dir,
                                                load_adv_name, TEST)

            # Define adversarial generator model
            self.adversarial_generator = Attacks(adversarial_model, eps,
                                                 len(self.train_data),
                                                 len(self.test_data),
                                                 adv_momentum, is_normalized,
                                                 store_adv)

            self.attack_fn = None
            if attack == MODE_PGD:
                self.attack_fn = self.adversarial_generator.fast_pgd
            elif attack == MODE_CW:
                self.attack_fn = self.adversarial_generator.carl_wagner
Code example #4
class Classifier:
    """
    WideResNet CIFAR-10 classifier supporting standard and adversarial
    training/testing (PGD or Carlini-Wagner attacks).
    """
    def __init__(self,
                 ds_name,
                 ds_path,
                 lr,
                 iterations,
                 batch_size,
                 print_freq,
                 k,
                 eps,
                 adv_momentum,
                 train_transform_fn,
                 test_transform_fn,
                 is_normalized,
                 store_adv=False,
                 load_dir=None,
                 load_name=None,
                 load_adv_dir=None,
                 load_adv_name=None,
                 save_dir=None,
                 attack=MODE_PLAIN,
                 train_mode=RAW,
                 test_mode=RAW,
                 mode=TRAIN_AND_TEST):

        # Load Data
        if ds_name == 'CIFAR10':
            self.train_data = torchvision.datasets.CIFAR10(
                ds_path,
                train=True,
                transform=train_transform_fn(),
                download=True)
            self.test_data = torchvision.datasets.CIFAR10(
                ds_path,
                train=False,
                transform=test_transform_fn(),
                download=True)
        else:
            raise ValueError('Unsupported dataset: ' + ds_name)

        # Data loaders
        self.train_loader = torch.utils.data.DataLoader(self.train_data,
                                                        batch_size=batch_size,
                                                        shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(self.test_data,
                                                       batch_size=batch_size)

        # Other Variables
        self.save_dir = save_dir
        self.store_adv = store_adv
        self.train_raw = (train_mode == RAW or train_mode == BOTH)
        self.train_adv = (train_mode == ADV or train_mode == BOTH)
        self.test_raw = (test_mode == RAW or test_mode == BOTH)
        self.test_adv = (test_mode == ADV or test_mode == BOTH)

        # Set Model Hyperparameters
        self.learning_rate = lr
        self.iterations = iterations
        self.print_freq = print_freq

        self.cuda = torch.cuda.is_available()

        # Load model for training
        self.model = self.load_model(self.cuda, load_dir, load_name, mode)

        # Define attack method to generate adversaries.
        if self.train_adv:

            # Load pre-trained model
            adversarial_model = self.load_model(self.cuda, load_adv_dir,
                                                load_adv_name, TEST)

            # Define adversarial generator model
            self.adversarial_generator = Attacks(adversarial_model, eps,
                                                 len(self.train_data),
                                                 len(self.test_data),
                                                 adv_momentum, is_normalized,
                                                 store_adv)

            self.attack_fn = None
            if attack == MODE_PGD:
                self.attack_fn = self.adversarial_generator.fast_pgd
            elif attack == MODE_CW:
                self.attack_fn = self.adversarial_generator.carl_wagner

    def load_model(self, is_cuda, load_dir=None, load_name=None, mode=None):
        """ Return WideResNet model, in gpu if applicable, and with provided checkpoint if given"""
        model = WideResNet(depth=28,
                           num_classes=10,
                           widen_factor=10,
                           dropRate=0.0)

        # Send to GPU if any
        if is_cuda:
            model = torch.nn.DataParallel(model).cuda()
            print(">>> SENDING MODEL TO GPU...")

        # Load checkpoint
        if load_dir and load_name and mode == TEST:
            model = self.load_checkpoint(model, load_dir, load_name)
            print(">>> LOADING PRE-TRAINED MODEL...")

        return model

    def train_step(self, x_batch, y_batch, optimizer, losses, top1, k=1):
        """ Performs a step during training. """
        # Compute output for example
        logits = self.model(x_batch)
        loss = self.model.module.loss(logits, y_batch)

        # Update Mean loss for current iteration
        losses.update(loss.item(), x_batch.size(0))
        prec1 = accuracy(logits.data, y_batch, k=k)
        top1.update(prec1.item(), x_batch.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # Set grads to zero for new iter
        optimizer.zero_grad()

    def test_step(self, x_batch, y_batch, losses, top1, k=1):
        """ Performs a step during testing."""
        with torch.no_grad():
            logits = self.model(x_batch)
            loss = self.model.module.loss(logits, y_batch)

        # Update Mean loss for current iteration
        losses.update(loss.item(), x_batch.size(0))
        prec1 = accuracy(logits.data, y_batch, k=k)
        top1.update(prec1.item(), x_batch.size(0))

    def train(self,
              momentum,
              nesterov,
              weight_decay,
              train_max_iter=1,
              test_max_iter=1):

        train_loss_hist = []
        train_acc_hist = []
        test_loss_hist = []
        test_acc_hist = []

        best_pred = 0.0

        end = time.time()

        for itr in range(self.iterations):

            self.model.train()

            optimizer = optim.SGD(self.model.parameters(),
                                  lr=compute_lr(self.learning_rate, itr),
                                  momentum=momentum,
                                  nesterov=nesterov,
                                  weight_decay=weight_decay)

            losses = AverageMeter()
            batch_time = AverageMeter()
            top1 = AverageMeter()

            x_adv = None

            for i, (x, y) in enumerate(self.train_loader):

                x = x.cuda()
                y = y.cuda()

                # Train raw examples
                if self.train_raw:
                    self.train_step(x, y, optimizer, losses, top1)

                # Train adversarial examples if applicable
                if self.train_adv:
                    x_adv, y_adv = self.attack_fn(x,
                                                  y,
                                                  train_max_iter,
                                                  mode='train')
                    self.train_step(x_adv, y_adv, optimizer, losses, top1)

                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.print_freq == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                              itr,
                              i,
                              len(self.train_loader),
                              batch_time=batch_time,
                              loss=losses,
                              top1=top1))

            # Evaluate on validation set
            test_loss, test_prec1 = self.test(self.test_loader, test_max_iter)

            train_loss_hist.append(losses.avg)
            train_acc_hist.append(top1.avg)
            test_loss_hist.append(test_loss)
            test_acc_hist.append(test_prec1)

            # Store best model
            is_best = best_pred < test_prec1
            self.save_checkpoint(is_best, (itr + 1), self.model.state_dict(),
                                 self.save_dir)
            if is_best:
                best_pred = test_prec1

            # Adversarial training examples were generated on this epoch; mark
            # them as stored so later epochs reuse them instead of regenerating.
            if self.train_adv and self.store_adv:
                self.adversarial_generator.set_stored('train', True)

        return (train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist)

    def test(self, batch_loader, test_max_iter=1):
        self.model.eval()

        losses = AverageMeter()
        batch_time = AverageMeter()
        top1 = AverageMeter()

        end = time.time()

        for i, (x, y) in enumerate(batch_loader):

            x = x.cuda()
            y = y.cuda()

            # Test on raw examples
            if self.test_raw:
                self.test_step(x, y, losses, top1)

            # Test on adversarial examples
            if self.test_adv:
                x_adv, y_adv = self.attack_fn(x, y, test_max_iter, mode='test')
                self.test_step(x_adv, y_adv, losses, top1)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                print('Epoch: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i,
                          len(batch_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1))

        # Adversarial test examples generated on this pass are kept; mark them
        # as stored so they are not recomputed on later calls.
        if self.test_adv and self.store_adv:
            self.adversarial_generator.set_stored('test', True)

        print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
        return (losses.avg, top1.avg)

    def save_checkpoint(self,
                        is_best,
                        epoch,
                        state,
                        save_dir,
                        base_name="chkpt_plain"):
        """Saves checkpoint to disk"""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        filename = os.path.join(save_dir, base_name + ".pth.tar")
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(
                filename,
                os.path.join(save_dir, base_name + '__model_best.pth.tar'))

    def load_checkpoint(self, model, load_dir, load_name):
        """Load checkpoint from disk"""
        filepath = os.path.join(load_dir, load_name)
        if os.path.exists(filepath):
            state_dict = torch.load(filepath)
            model.load_state_dict(state_dict)
            print("Loaded checkpoint...")
            return model

        print("Failed to load model. Exiting...")
        sys.exit(1)
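A sketch of how the Classifier above might be instantiated and trained, assuming its definitions (including the MODE_PGD and BOTH constants) are in scope. The transform factories, paths, and hyperparameter values are illustrative assumptions; the checkpoint name follows the save_checkpoint defaults shown above.

# Hypothetical usage sketch for the Classifier above; values are assumptions.
import torchvision.transforms as T

def train_transform():
    return T.Compose([T.RandomCrop(32, padding=4),
                      T.RandomHorizontalFlip(),
                      T.ToTensor()])

def test_transform():
    return T.Compose([T.ToTensor()])

clf = Classifier(ds_name='CIFAR10',
                 ds_path='./data',
                 lr=0.1,
                 iterations=100,
                 batch_size=128,
                 print_freq=50,
                 k=1,
                 eps=8.0 / 255,
                 adv_momentum=0.9,
                 train_transform_fn=train_transform,
                 test_transform_fn=test_transform,
                 is_normalized=False,
                 store_adv=True,
                 load_adv_dir='./checkpoints/',   # model used to craft adversaries (assumed path)
                 load_adv_name='chkpt_plain__model_best.pth.tar',
                 save_dir='./checkpoints/',
                 attack=MODE_PGD,
                 train_mode=BOTH,
                 test_mode=BOTH)

history = clf.train(momentum=0.9, nesterov=True, weight_decay=5e-4,
                    train_max_iter=7, test_max_iter=20)

Setting train_mode=BOTH makes each epoch take a step on the raw batch and on the PGD batch, matching the train() loop above.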
Code example #5
    def test():
        # set the evaluation mode
        model.eval()
        n = len(args.metric_test)
        # Per-metric test loss accumulators
        test_loss = [0] * n

        # Gradients are intentionally left enabled: the FGSM/PGD/VAEA attacks
        # below need gradients with respect to the inputs, so the usual
        # torch.no_grad() wrapper is omitted.
        for i, (x, label) in enumerate(test_loader):
            x = x.to(device)
            if args.model == "fc":
                # Flatten images for the fully connected model
                x = flatten(x)
            for j, m in enumerate(args.metric_test):
                if m == "original":
                    # Forward pass on the clean batch
                    x_recon, mu, var = model(x)

                    if args.model == "cnn":
                        loss = model.loss_fn(x_recon, x, mu, var, args.KLweight)
                    elif args.model == "fc":
                        loss = loss_fn_fc(x_recon, x, mu, var, args.KLweight)

                    test_loss[j] += loss.item()
                elif m == "FGSM":

                    if args.model == "fc":
                        x_flat = flatten(x)
                        # print(images.shape) #128*28^2
                        # create adv images.perform step on the reconstructed image generated by adv images.

                        model.eval()
                        adv = Attacks.fgsm_untargeted(model, x_flat, label=7, eps=0.15, clip_min=0,
                                                    clip_max=1.0)
                        # model.train()
                        # print("adv.shape ",adv.shape)

                        recon_images, mu, logvar = model(adv)
                        # print(recon_images.shape)
                        adv = adv.detach()
                        loss = loss_fn_fc(recon_images, x_flat, mu, logvar, args.KLweight)

                    if args.model == "cnn":
                        adv = Attacks.fgsm_untargeted(model, x, label=7,  eps=0.15, is_fc=False,
                                                    clip_min=0,
                                                    clip_max=1.0)  # do something here.
                        x_recon, mu, var = model(adv)
                        loss = model.loss_fn(x_recon, x, mu, var, args.KLweight)

                    test_loss[j] += loss.item()
                elif m == "PGD":
                    if args.model == "fc":

                        x_flat = flatten(x)
                        # print(images.shape) #128*28^2
                        # create adv images.perform step on the reconstructed image generated by adv images.

                        model.eval()
                        adv = Attacks.pgd_untargeted(model, x_flat, label=7, k=10, eps=0.08, eps_step=0.05, clip_min=0,
                                             clip_max=1.0)
                        # model.train()
                        # print("adv.shape ",adv.shape)

                        recon_images, mu, logvar = model(adv)
                        # print(recon_images.shape)
                        adv = adv.detach()
                        loss = loss_fn_fc(recon_images, x_flat, mu, logvar,args.KLweight)

                    if args.model == "cnn":
                        adv = Attacks.pgd_untargeted(model, x, label=7, k=10, eps=0.08, eps_step=0.05,is_fc =False,
                                                    clip_min=0,
                                                    clip_max=1.0) # do something here.
                        x_recon, mu, var = model(adv)
                        loss = model.loss_fn(x_recon, x, mu, var, args.KLweight)
                elif m == "GNA":
                    if args.model == "fc":

                        # create adv images.perform step on the reconstructed image generated by adv images.

                        model.eval()
                        adv = Attacks.GNA(model, x,clip_min=0, clip_max=1.0, _sigma = 0.1,_mean=0.1,rand=True)
                        x_flat = flatten(x)
                        # model.train()
                        # print("adv.shape ",adv.shape)

                        recon_images, mu, logvar = model(adv)
                        # print(recon_images.shape)
                        adv = adv.detach()
                        loss = loss_fn_fc(recon_images, x_flat, mu, logvar,args.KLweight)

                    if args.model == "cnn":
                        adv = Attacks.GNA(model, x,clip_min=0, clip_max=1.0, _sigma = 0.1,_mean=0.1,rand=True)
                        
                        x_recon, mu, var = model(adv)
                        loss = model.loss_fn(x_recon, x, mu, var, args.KLweight)
                        
                elif m == "GTA":
                    if args.model == "fc":

                        # create adv images.perform step on the reconstructed image generated by adv images.

                        model.eval()
                        adv =Attacks.GTA(model, x, rotate=30,translate=([0.01,0.01]),scale = ([0.8,1.2]))
                        x_flat = flatten(x)
                        # model.train()
                        # print("adv.shape ",adv.shape)

                        recon_images, mu, logvar = model(adv)
                        # print(recon_images.shape)
                        adv = adv.detach()
                        loss = loss_fn_fc(recon_images, x_flat, mu, logvar,args.KLweight)

                    if args.model == "cnn":
                        adv = Attacks.GTA(model, x, rotate=30,translate=([0.01,0.01]),scale = ([0.8,1.2]))
                        
                        x_recon, mu, var = model(adv)
                        loss = model.loss_fn(x_recon, x, mu, var, args.KLweight)

                    test_loss[j] += loss.item()
                elif m == "VAEA":
                    if args.model == "fc":

                        # create adv images.perform step on the reconstructed image generated by adv images.
                        norm = args.loss
                        model.eval()
                        
                        (target, _) = DS.sample_adv_untargeted_quick(quick_targets_dataset, source=x, label=label)
                        A = Attacks(model)
                        adv = A.VAEA(model, x, target, chosen_norm=norm, steps=10, eps=0.08, eps_norm=norm)
                        x_flat = flatten(x)
                        # print("adv.shape ",adv.shape)

                        recon_images, mu, logvar = model(adv)
                        # print(recon_images.shape)
                        adv = adv.detach()
                        loss = loss_fn_fc(recon_images, x_flat, mu, logvar,args.KLweight)

                    if args.model == "cnn":
                        adv = Attacks.GNA(model, x,clip_min=0, clip_max=1.0, _sigma = 0.1,_mean=0.1,rand=True)
                        
                        x_recon, mu, var = model(adv)
                        loss = model.loss_fn(x_recon, x, mu, var, args.KLweight)
        return test_loss
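The test() closure above depends on names from its enclosing scope (model, device, test_loader, args, flatten, loss_fn_fc, DS, quick_targets_dataset). Below is a minimal sketch of the kind of configuration it expects, assuming an argparse-style namespace; the field values, the "l2" norm string, and the flatten helper are assumptions, not taken from the original project.

# Hypothetical sketch of the enclosing configuration; values are assumptions.
from argparse import Namespace

import torch

args = Namespace(
    model="fc",                               # "fc" or "cnn"
    metric_test=["original", "FGSM", "PGD"],  # one accumulator per metric
    KLweight=1.0,                             # weight of the KL term in the VAE loss
    loss="l2",                                # norm passed to the VAEA attack
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def flatten(x):
    # Assumed helper: flatten image batches for the fully connected VAE
    return x.view(x.size(0), -1)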
Code example #6
  def __init__(self):
    self.config = self.load_config()
    self.interface = Interface(self.config['interface'])
    self.tools = Tools(self.interface, self.config['tools'])
    self.attacks = Attacks(self.tools, self.config['attacks'])
    self.tools.stop_interfaces()
Code example #7
class Main:

  def __init__(self):
    self.config = self.load_config()
    self.interface = Interface(self.config['interface'])
    self.tools = Tools(self.interface, self.config['tools'])
    self.attacks = Attacks(self.tools, self.config['attacks'])
    self.tools.stop_interfaces()

  def load_config(self):
    with open('data/config/config.json') as f:
      return json.load(f)

  def main_interface(self):
    return self.interface.show('EVIL-ESP', [
      Voice('Options', self.options_interface),
      Voice('Attacks', self.attacks_interface),
      Voice('Logs', self.logs_interface)
    ])

  def options_interface(self):
    return self.interface.show('OPTIONS', [
      Voice('Back', self.main_interface),
      Voice('Config Mode', self.tools.start_config_mode),
      Voice('Sleep', self.device_sleep)
    ])

  def attacks_interface(self):
    return self.interface.show('ATTACKS', [
      Voice('Back', self.main_interface),
      Voice('Beacon Spammer', self.attacks.start_beacon_spammer),
      Voice('Captive Portal', self.attacks.start_captive_portal),
      Voice('Evil Twin', self.evil_twin_interface)
    ])

  def create_lambda(self, essid):
    return lambda: self.attacks.start_evil_twin(essid)

  def evil_twin_interface(self):
    self.interface.show_single(['Scanning'])
    voices = [Voice('No networks', self.attacks_interface)]
    networks = self.tools.scan_networks()
    if len(networks) > 0:
      voices[0].name = 'Back'
      for net in networks:
        if net[4] > 0:
          essid = str(net[0], 'utf-8')
          callback = self.create_lambda(essid)
          voices.append(Voice(essid, callback))
    return self.interface.show('TWIN', voices)

  def logs_interface(self):
    return self.interface.show('LOGS', [
      Voice('Back', self.main_interface),
      Voice('Captive Portal', self.captive_portal_logs_interface),
      Voice('Evil Twin', self.evil_twin_logs_interface)
    ])

  def captive_portal_logs_interface(self):
    voices = [Voice('No credentials', self.logs_interface)]
    logs = self.attacks.get_captive_portal_logs()
    if len(logs) > 0:
      voices = []
      for log in logs:
        log = log.replace('\n', '')
        voices.append(Voice(log, self.logs_interface))
    return self.interface.show('PORTAL', voices)

  def evil_twin_logs_interface(self):
    voices = [Voice('No credentials', self.logs_interface)]
    logs = self.attacks.get_evil_twin_logs()
    if len(logs) > 0:
      voices = []
      for log in logs:
        log = log.replace('\n', '')
        voices.append(Voice(log, self.logs_interface))
    return self.interface.show('TWIN', voices)


  def device_sleep(self):
    self.interface.sleep_screen()
    machine.deepsleep()

  def start(self):
    display_enabled = self.config['interface']['display']['enabled']
    button_enabled = self.config['interface']['button']['enabled']
    default_attack = self.config['attacks']['default']
    if display_enabled and button_enabled:
      self.loop()
    elif default_attack == 'beacon_spammer':
      self.attacks.start_beacon_spammer()
    elif default_attack == 'captive_portal':
      self.attacks.start_captive_portal()
    elif default_attack == 'evil_twin':
      target_essid = self.config['attacks']['evil_twin']['default_essid']
      self.attacks.start_evil_twin(target_essid)

  def loop(self, callback=None):
    while True:
      if not callback:
        callback = self.main_interface
      callback = callback()
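On the device, an entry-point script would typically just build Main and hand over control; the module name used in the import below is an assumption.

# main.py -- hypothetical MicroPython entry point; the module path is assumed.
from evil_esp import Main

app = Main()
app.start()

start() falls back to the configured default attack when the display or button is disabled, as shown in the snippet above.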
Code example #8
class Classifier:
    """
    Trains a Denoiser in front of a frozen WideResNet target model so that
    denoised adversarial examples yield logits close to the clean ones.
    """
    def __init__(self,
                 ds_name,
                 ds_path,
                 lr,
                 iterations,
                 batch_size,
                 print_freq,
                 k,
                 eps,
                 is_normalized,
                 adv_momentum,
                 store_adv=None,
                 load_adv_dir=None,
                 load_adv_name=None,
                 load_dir=None,
                 load_name=None,
                 save_dir=None):

        self.data_processor = Preprocessor(ds_name, ds_path, is_normalized)

        # Load Data
        self.train_data, self.test_data, self.N_train, self.N_test = \
            self.data_processor.datasets()
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=batch_size,
                                       shuffle=True)
        self.test_loader = DataLoader(self.test_data, batch_size=batch_size)

        # Other Variables
        self.save_dir = save_dir
        self.store_adv = store_adv

        # Set Model Hyperparameters
        self.learning_rate = lr
        self.iterations = iterations
        self.print_freq = print_freq
        self.cuda = torch.cuda.is_available()

        # Load Model to Conduct Adversarial Training
        adversarial_model = self.load_model(self.cuda, load_adv_dir,
                                            load_adv_name, TEST)
        self.adversarial_generator = Attacks(adversarial_model, eps,
                                             self.N_train, self.N_test,
                                             self.data_processor.get_const(),
                                             adv_momentum, is_normalized,
                                             store_adv)

        # Load Target Model
        self.target_model = self.load_model(self.cuda, load_dir, load_name,
                                            TEST)

        # Load Denoiser
        self.denoiser = Denoiser(x_h=32, x_w=32)
        self.denoiser = self.denoiser.cuda()


    def load_model(self, is_cuda, load_dir=None, load_name=None, mode=None):
        """ Return WideResNet model, in gpu if applicable, and with provided checkpoint if given"""
        model = WideResNet(depth=28,
                           num_classes=10,
                           widen_factor=10,
                           dropRate=0.0)

        # Send to GPU if any
        if is_cuda:
            model = torch.nn.DataParallel(model).cuda()
            print(">>> SENDING MODEL TO GPU...")

        # Load checkpoint
        if load_dir and load_name and mode == TEST:
            model = self.load_checkpoint(model, load_dir, load_name)
            print(">>> LOADING CHECKPOINT:", load_dir)

        return model

    def grad_step(self, x_batch, y_batch):
        """ Forward pass through the target model with gradients enabled. """
        logits = self.target_model(x_batch)
        loss = self.target_model.module.loss(logits, y_batch)

        return logits, loss

    def no_grad_step(self, x_batch, y_batch):
        """ Forward pass through the target model without tracking gradients. """
        with torch.no_grad():
            logits = self.target_model(x_batch)
            loss = self.target_model.module.loss(logits, y_batch)

        return logits, loss

    def train(self, train_max_iter=1, test_max_iter=1):

        self.target_model.eval()

        train_loss_hist = []
        train_acc_hist = []
        test_loss_hist = []
        test_acc_hist = []

        best_pred = 0.0

        end = time.time()

        for itr in range(self.iterations):

            # The target model stays frozen; only the denoiser is optimized
            optimizer = optim.Adam(self.denoiser.parameters(),
                                   lr=self.learning_rate)

            losses = AverageMeter()
            batch_time = AverageMeter()
            top1 = AverageMeter()

            x_adv = None
            stored = self.adversarial_generator.get_stored(mode='train')

            for i, (x, y) in enumerate(self.train_loader):

                x = x.cuda()
                y = y.cuda()

                # Generate adversarial examples with PGD
                if not stored:
                    # 1. Generate predictions for the batch
                    logits, _ = self.no_grad_step(x, y)
                    y_pred = torch.argmax(logits, dim=1)

                    # 2. Generate adversaries from y_pred (avoids the 'label leak' problem)
                    x_adv, _ = self.adversarial_generator.fast_pgd(
                        x, y_pred, train_max_iter, mode='train')
                    self.adversarial_generator.retain_adversaries(x_adv,
                                                                  y,
                                                                  mode='train')
                else:
                    x_adv, _ = self.adversarial_generator.fast_pgd(
                        x, y, train_max_iter, mode='train')

                # 3. Denoise: the denoiser predicts a residual that is added
                #    back onto the adversarial image
                noise = self.denoiser(x_adv)
                x_smooth = x_adv + noise

                # 4. Get logits from smooth and denoised image
                logits_smooth, _ = self.grad_step(x_smooth, y)
                logits_org, _ = self.grad_step(x, y)

                # 5. Compute loss
                loss = torch.sum(
                    torch.abs(logits_smooth - logits_org)) / x.size(0)

                # 6. Update Mean loss for current iteration
                losses.update(loss.item(), x.size(0))
                prec1 = accuracy(logits_smooth.data, y)
                top1.update(prec1.item(), x.size(0))

                # compute gradient and do SGD step
                loss.backward()
                optimizer.step()

                # Set grads to zero for new iter
                optimizer.zero_grad()

                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.print_freq == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                              itr,
                              i,
                              len(self.train_loader),
                              batch_time=batch_time,
                              loss=losses,
                              top1=top1))

            # Evaluation on the validation set and checkpointing are currently
            # disabled for this denoiser training loop.

            # Adversarial examples were generated on this epoch; mark them as
            # stored so later epochs reuse them instead of regenerating.
            if self.store_adv:
                self.adversarial_generator.set_stored('train', True)
        return (train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist)

    def test(self, batch_loader, test_max_iter=1):
        self.target_model.eval()

        losses = AverageMeter()
        batch_time = AverageMeter()
        top1 = AverageMeter()

        end = time.time()

        for i, (x, y) in enumerate(batch_loader):

            x = x.cuda()
            y = y.cuda()

            # Evaluate the frozen target model on raw examples
            logits, loss = self.no_grad_step(x, y)
            losses.update(loss.item(), x.size(0))
            prec1 = accuracy(logits.data, y)
            top1.update(prec1.item(), x.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                print('Epoch: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i,
                          len(batch_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1))

        print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
        return (losses.avg, top1.avg)

    def save_checkpoint(self,
                        is_best,
                        epoch,
                        state,
                        save_dir,
                        base_name="chkpt"):
        """Saves checkpoint to disk"""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        filename = os.path.join(save_dir, base_name + ".pth.tar")
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(
                filename,
                os.path.join(save_dir, base_name + '__model_best.pth.tar'))

    def load_checkpoint(self, model, load_dir, load_name):
        """Load checkpoint from disk"""
        filepath = os.path.join(load_dir, load_name)
        if os.path.exists(filepath):
            state_dict = torch.load(filepath)
            model.load_state_dict(state_dict)
            return model

        print("Failed to load model. Exiting...")
        sys.exit(1)
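A sketch of how this denoiser-training variant might be driven, mirroring the constructor above and assuming its definitions are in scope; paths, hyperparameters, and checkpoint names are illustrative assumptions.

# Hypothetical usage sketch; all concrete values are assumptions.
clf = Classifier(ds_name='CIFAR10',
                 ds_path='./data',
                 lr=1e-3,
                 iterations=30,
                 batch_size=128,
                 print_freq=50,
                 k=1,
                 eps=8.0 / 255,
                 is_normalized=False,
                 adv_momentum=0.9,
                 store_adv=True,
                 load_adv_dir='./checkpoints/',   # model used to craft PGD adversaries (assumed path)
                 load_adv_name='chkpt_plain__model_best.pth.tar',
                 load_dir='./checkpoints/',       # frozen target model (assumed path)
                 load_name='chkpt_plain__model_best.pth.tar',
                 save_dir='./checkpoints/')

# train_max_iter is forwarded to fast_pgd as the number of attack iterations.
clf.train(train_max_iter=7)

The denoiser is optimized with Adam so that the target model's logits on the denoised adversarial images stay close to its logits on the clean images, which is exactly the loss computed in train() above.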