Example #1
    def __init__(self, args):

        self.args = args

        transformer = T.Compose([
            T.ToTensor()
        ])
        kwargs = {'num_workers': 4, 'pin_memory': True}

        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root, train=False, transform=transformer),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=32, width=10, num_classes=10)
        self.model = torch.nn.DataParallel(self.model).cuda()

        # Loading model
        assert self.args.restore is not None

        model_data = torch.load(self.args.restore)
        self.model.load_state_dict(model_data['model'])
        self.model.eval()

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)
        self.save_path = self.args.save_path
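
These examples construct PGDAttacker(args.attack_eps) and call attacker.attack(input, target, model, attack_steps, attack_lr, random_init=..., target=...) without ever showing the class. Note that some call sites pass the true labels positionally and also pass target=None as a keyword, which only works if the second parameter is not itself named target. A minimal L-infinity PGD sketch consistent with those call sites (the parameter names labels and target and the [0, 1] clipping range are assumptions, not the repository's actual code):

import torch
import torch.nn.functional as F

class PGDAttacker:
    """Sketch of the attacker interface assumed by these examples."""

    def __init__(self, attack_eps):
        self.eps = attack_eps

    def attack(self, x, labels, model, attack_steps, attack_lr,
               random_init=True, target=None):
        # `labels` are the true classes; a non-None `target` switches to a
        # targeted attack (assumed semantics).
        x_adv = x.clone().detach()
        if random_init:
            x_adv = x_adv + torch.empty_like(x_adv).uniform_(-self.eps, self.eps)
            x_adv = torch.clamp(x_adv, 0, 1)
        for _ in range(attack_steps):
            x_adv.requires_grad_(True)
            logits = model(x_adv)
            if target is None:
                loss = F.cross_entropy(logits, labels)    # untargeted: ascend
            else:
                loss = -F.cross_entropy(logits, target)   # targeted: descend
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv.detach() + attack_lr * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - self.eps), x + self.eps)
            x_adv = torch.clamp(x_adv, 0, 1)
        return x_adv.detach()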
Example #2
    def __init__(self, args):

        self.args = args

        transformer = T.Compose([T.ToTensor()])
        kwargs = {'num_workers': 4, 'pin_memory': True}

        train_set = datasets.CIFAR10(args.data_root,
                                     train=True,
                                     transform=transformer,
                                     download=True)
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        self.train_samples_np = train_set.data.astype(np.float32)
        self.train_samples_np = self.train_samples_np.transpose(0, 3, 1, 2)
        self.train_samples_np = np.reshape(
            self.train_samples_np, (self.train_samples_np.shape[0], -1))

        self.train_samples_np = self.train_samples_np / 255.0
        self.labels_np = np.array(train_set.targets)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=32, width=10, num_classes=10)
        self.model = torch.nn.DataParallel(self.model).cuda()

        # Loading model
        assert self.args.restore is not None

        model_data = torch.load(self.args.restore)
        self.model.load_state_dict(model_data['model'])
        self.model.eval()

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)
        self.save_path = self.args.save_path
Example #3
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        transform_train = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomCrop(32),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])

        transform_test = T.Compose([T.ToTensor()])

        kwargs = {'num_workers': 4, 'pin_memory': True}

        train_dataset = datasets.CIFAR10(args.data_root,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root,
                             train=False,
                             transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=32, width=10, num_classes=10)
        self.model = torch.nn.DataParallel(self.model).cuda()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   args.lr,
                                   momentum=0.9,
                                   weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[70, 90, 100], gamma=0.2)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.save_path = args.save_path
        self.epoch = 0

        num_samples = len(train_dataset)
        self.epsilon_memory = torch.FloatTensor(num_samples).zero_().cuda()

        # resume from checkpoint
        ckpt_path = osp.join(args.save_path, 'checkpoint.pth')
        if osp.exists(ckpt_path):
            self._load_from_checkpoint(ckpt_path)
        elif args.restore:
            self._load_from_checkpoint(args.restore)

        cudnn.benchmark = True
        self.attacker = PGDAttackerAdaptive()
        self.attacker_test = PGDAttacker(args.attack_eps)
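
Unlike PGDAttacker, PGDAttackerAdaptive is constructed without an epsilon and receives per-sample attack_lr and epsilon tensors at call time (see epsilon_select in the TrainerAdaptive class below). A sketch of such a per-sample variant, under the same assumptions as the PGDAttacker sketch above:

import torch
import torch.nn.functional as F

class PGDAttackerAdaptive:
    """PGD with per-sample step size and radius; interface assumed from the call sites."""

    def attack(self, x, labels, model, attack_steps, attack_lr, epsilon,
               random_init=True, target=None):
        # attack_lr and epsilon are shape-(B,) tensors; broadcast over (B, C, H, W).
        eps = epsilon.view(-1, 1, 1, 1)
        lr = attack_lr.view(-1, 1, 1, 1)
        x_adv = x.clone().detach()
        if random_init:
            x_adv = torch.clamp(x_adv + (2 * torch.rand_like(x_adv) - 1) * eps, 0, 1)
        for _ in range(attack_steps):
            x_adv.requires_grad_(True)
            # target is accepted for signature compatibility; only the
            # untargeted case is sketched here.
            loss = F.cross_entropy(model(x_adv), labels)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv.detach() + lr * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - eps), x + eps)
            x_adv = torch.clamp(x_adv, 0, 1)
        return x_adv.detach()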
Example #4
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        transform_train = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomCrop(32),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])

        transform_test = T.Compose([T.ToTensor()])

        kwargs = {'num_workers': 8, 'pin_memory': True}

        self.train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root,
                             train=True,
                             download=True,
                             transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root,
                             train=False,
                             transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=34, width=1, num_classes=10)

        self.spbn_flag = args.spbn
        if self.spbn_flag:
            print("SPBN training!")
            self.model = models.convert_splitbn_model(self.model,
                                                      momentum=0.5).cuda()
        else:
            self.model.cuda()

        self.lambda_ = 0.9

        # spbn_1 = 0.7, adv_momentum = 0.1
        # spbn_2 = 0.7, adv_momentum = 0.01
        # spbn_3 = 0.9, adv_momentum = 0.01
        # spbn_4 = 0.9, adv_momentum = 0.5

        self.optimizer = optim.SGD(self.model.parameters(),
                                   args.lr,
                                   momentum=0.9,
                                   weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[60, 120, 160], gamma=0.1)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.save_path = args.save_path
        self.epoch = 0
        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)

        # resume from checkpoint
        if args.resume:
            ckpt_path = os.path.join(args.save_path, 'checkpoint.pth')
            if os.path.exists(ckpt_path):
                self._load_from_checkpoint(ckpt_path)
            elif args.restore:
                self._load_from_checkpoint(args.restore)
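
convert_splitbn_model is not shown in these snippets. Conceptually, as in timm's split BatchNorm (AdvProp-style), it replaces each BatchNorm2d so that the first chunk of a concatenated batch (the clean half here) uses the original statistics and later chunks (the adversarial half) use auxiliary ones; this is why the training loop in the next example feeds torch.cat((input, adv_input)) through the model and splits the logits back. A minimal sketch of the idea, with the class name and constructor as assumptions:

import torch
import torch.nn as nn

class SplitBatchNorm2d(nn.BatchNorm2d):
    """Route equal chunks of the batch through separate BN statistics (sketch)."""

    def __init__(self, num_features, num_splits=2, momentum=0.5):
        super().__init__(num_features, momentum=momentum)
        self.aux_bn = nn.ModuleList(
            nn.BatchNorm2d(num_features, momentum=momentum)
            for _ in range(num_splits - 1))

    def forward(self, x):
        # Assumes the batch is divisible into num_splits equal chunks.
        chunks = x.chunk(len(self.aux_bn) + 1, dim=0)
        out = [super().forward(chunks[0])]                         # clean statistics
        out += [bn(c) for bn, c in zip(self.aux_bn, chunks[1:])]   # adversarial statistics
        return torch.cat(out, dim=0)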
Example #5
class Trainer:
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        transform_train = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomCrop(32),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])

        transform_test = T.Compose([T.ToTensor()])

        kwargs = {'num_workers': 8, 'pin_memory': True}

        self.train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root,
                             train=True,
                             download=True,
                             transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root,
                             train=False,
                             transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=34, width=1, num_classes=10)

        self.spbn_flag = args.spbn
        if self.spbn_flag:
            print("SPBN training!")
            self.model = models.convert_splitbn_model(self.model,
                                                      momentum=0.5).cuda()
        else:
            self.model.cuda()

        self.lambda_ = 0.9

        # spbn_1 = 0.7, adv_momentum = 0.1
        # spbn_2 = 0.7, adv_momentum = 0.01
        # spbn_3 = 0.9, adv_momentum = 0.01
        # spbn_4 = 0.9, adv_momentum = 0.5

        self.optimizer = optim.SGD(self.model.parameters(),
                                   args.lr,
                                   momentum=0.9,
                                   weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[60, 120, 160], gamma=0.1)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.save_path = args.save_path
        self.epoch = 0
        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)

        # resume from checkpoint
        if args.resume:
            ckpt_path = os.path.join(args.save_path, 'checkpoint.pth')
            if os.path.exists(ckpt_path):
                self._load_from_checkpoint(ckpt_path)
            elif args.restore:
                self._load_from_checkpoint(args.restore)

    def _log(self, message):
        print(message)
        f = open(os.path.join(self.save_path, 'log.txt'), 'a+')
        f.write(message + '\n')
        f.close()

    def _load_from_checkpoint(self, ckpt_path):
        print('Loading model from {} ...'.format(ckpt_path))
        model_data = torch.load(ckpt_path)
        self.model.load_state_dict(model_data['model'])
        self.optimizer.load_state_dict(model_data['optimizer'])
        self.lr_scheduler.load_state_dict(model_data['lr_scheduler'])
        self.epoch = model_data['epoch'] + 1
        print('Model loaded successfully')

    def _save_checkpoint(self, best=False):
        self.model.eval()
        model_data = dict()
        model_data['model'] = self.model.state_dict()
        model_data['optimizer'] = self.optimizer.state_dict()
        model_data['lr_scheduler'] = self.lr_scheduler.state_dict()
        model_data['epoch'] = self.epoch
        if best:
            torch.save(model_data, os.path.join(self.save_path, 'best.pth'))
        else:
            torch.save(model_data,
                       os.path.join(self.save_path, 'checkpoint.pth'))

    def train(self):

        losses = utils.AverageMeter()
        # summary writer
        log_dir = os.path.join(self.save_path, 'training_log')
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        self.writer = SummaryWriter(log_dir)
        best_acc = 0
        while self.epoch < self.args.nepochs:
            self.model.train()
            correct = 0
            total = 0
            start_time = time.time()
            tq = tqdm(enumerate(self.train_loader),
                      total=len(self.train_loader),
                      leave=True)
            for i, data in tq:
                input, target = data
                target = target.cuda(non_blocking=True)
                input = input.cuda(non_blocking=True)
                adv_input = self.attacker.attack(input,
                                                 target,
                                                 self.model,
                                                 self.args.attack_steps,
                                                 self.args.attack_lr,
                                                 random_init=True,
                                                 target=None)

                # compute output
                self.optimizer.zero_grad()

                if self.spbn_flag:
                    concat = torch.cat((input, adv_input), dim=0)
                    logits = self.model(concat)
                    clean_logits, adv_logits = torch.split(logits,
                                                           target.size(0),
                                                           dim=0)

                    adv_loss = F.cross_entropy(adv_logits, target)
                    clean_loss = F.cross_entropy(clean_logits, target)

                    loss = self.lambda_ * adv_loss + (
                        1 - self.lambda_) * clean_loss
                else:
                    clean_logits = self.model(input)
                    loss = F.cross_entropy(clean_logits, target)
                    mean, std = models.print_mean_std(self.model)

                loss.backward()
                self.optimizer.step()

                _, pred = torch.max(clean_logits, dim=1)
                correct += (pred == target).sum()
                total += target.size(0)
                acc = (float(correct) / total) * 100

                # measure accuracy and record loss
                losses.update(loss.data.item(), input.size(0))
                message = 'Epoch {}/{}, Loss: {:.4f}, Accuracy: {:.4f}'.format(
                    self.epoch, self.args.nepochs, loss.item(), acc)
                tq.set_description(message)

                # writing log in TensorBoard
                global_step = self.epoch * len(self.train_loader) + i
                if self.spbn_flag:
                    self.writer.add_scalars(
                        'Adv_training/loss', {
                            'clean_loss': clean_loss.item(),
                            'adv_loss': adv_loss.item(),
                            'entire_loss': loss.item()
                        }, global_step)
                    self.writer.add_scalar('Adv_training/Acc', acc, global_step)
                else:
                    self.writer.add_scalar('Adv_training/loss', loss.item(),
                                           global_step)
                    self.writer.add_scalar('Adv_training/Acc', acc, global_step)

            self.epoch += 1
            self.lr_scheduler.step()
            end_time = time.time()
            batch_time = end_time - start_time
            message = 'Epoch {}, Time {}, Loss: {}, Accuracy: {}'.format(
                self.epoch, batch_time, loss.item(), acc)
            self._log(message)
            self._save_checkpoint()

            # Evaluation
            if self.epoch % 5 == 0:
                nat_acc = self.eval()
                adv_acc = self.eval_adversarial()

                if adv_acc > best_acc:
                    print("Saving..")
                    self._save_checkpoint(best=True)
                    best_acc = adv_acc
                self._log('Natural accuracy: {}'.format(nat_acc))
                self._log('Adv accuracy: {}'.format(adv_acc))

            self._save_checkpoint(best=False)

    def eval(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy

    def eval_adversarial(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            input = self.attacker.attack(input,
                                         target,
                                         self.model,
                                         self.args.attack_steps,
                                         self.args.attack_lr,
                                         random_init=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy
Example #6
class TrainerAdaptive:
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        transform_train = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomCrop(32),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])

        transform_test = T.Compose([T.ToTensor()])

        kwargs = {'num_workers': 4, 'pin_memory': True}

        train_dataset = datasets.CIFAR10(args.data_root,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root,
                             train=False,
                             transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=32, width=10, num_classes=10)
        self.model = torch.nn.DataParallel(self.model).cuda()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   args.lr,
                                   momentum=0.9,
                                   weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[70, 90, 100], gamma=0.2)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.save_path = args.save_path
        self.epoch = 0

        num_samples = len(train_dataset)
        self.epsilon_memory = torch.FloatTensor(num_samples).zero_().cuda()

        # resume from checkpoint
        ckpt_path = osp.join(args.save_path, 'checkpoint.pth')
        if osp.exists(ckpt_path):
            self._load_from_checkpoint(ckpt_path)
        elif args.restore:
            self._load_from_checkpoint(args.restore)

        cudnn.benchmark = True
        self.attacker = PGDAttackerAdaptive()
        self.attacker_test = PGDAttacker(args.attack_eps)

    def _log(self, message):
        print(message)
        f = open(osp.join(self.save_path, 'log.txt'), 'a+')
        f.write(message + '\n')
        f.close()

    def _load_from_checkpoint(self, ckpt_path):
        print('Loading model from {} ...'.format(ckpt_path))
        model_data = torch.load(ckpt_path)
        self.model.load_state_dict(model_data['model'])
        self.optimizer.load_state_dict(model_data['optimizer'])
        self.lr_scheduler.load_state_dict(model_data['lr_scheduler'])
        self.epoch = model_data['epoch'] + 1
        print('Model loaded successfully')

        eps_memory = model_data['mem']
        self.epsilon_memory = eps_memory
        self.epsilon_memory = self.epsilon_memory.cuda()

    def _save_checkpoint(self, model_name='checkpoint.pth'):
        self.model.eval()
        model_data = dict()
        model_data['model'] = self.model.state_dict()
        model_data['optimizer'] = self.optimizer.state_dict()
        model_data['lr_scheduler'] = self.lr_scheduler.state_dict()
        model_data['epoch'] = self.epoch
        model_data['mem'] = self.epsilon_memory
        torch.save(model_data, osp.join(self.save_path, model_name))

    def epsilon_select(self, input, target, indices):
        # self.model.eval()
        with torch.no_grad():
            logits = self.model(input)
            _, pred = torch.max(logits, dim=1)
            correct_preds_clean = (pred == target).float()

        if self.epoch < self.args.warmup:
            epsilon = torch.zeros(input.size(0)).fill_(
                self.args.attack_eps).cuda()
            epsilon = epsilon * correct_preds_clean
        else:
            epsilon_prev = self.epsilon_memory[indices]
            epsilon_low = epsilon_prev - self.args.gamma
            epsilon_cur = epsilon_prev
            epsilon_high = epsilon_prev + self.args.gamma
            attack_lr_cur = torch.clamp(epsilon_cur /
                                        (0.5 * self.args.attack_steps),
                                        min=self.args.attack_lr)
            attack_lr_high = torch.clamp(epsilon_high /
                                         (0.5 * self.args.attack_steps),
                                         min=self.args.attack_lr)

            input_cur = self.attacker.attack(input,
                                             target,
                                             self.model,
                                             self.args.attack_steps,
                                             attack_lr_cur,
                                             epsilon_cur,
                                             random_init=True,
                                             target=None)
            input_high = self.attacker.attack(input,
                                              target,
                                              self.model,
                                              self.args.attack_steps,
                                              attack_lr_high,
                                              epsilon_high,
                                              random_init=True,
                                              target=None)

            with torch.no_grad():
                logits_cur = self.model(input_cur)
                logits_high = self.model(input_high)
                _, pred_cur = torch.max(logits_cur, dim=1)
                _, pred_high = torch.max(logits_high, dim=1)

                robust_cur = (pred_cur == target).float()
                robust_high = (pred_high == target).float()

                epsilon = robust_high * epsilon_high + \
                          (1 - robust_high) * robust_cur * epsilon_cur + \
                          (1 - robust_high) * (1 - robust_cur) * epsilon_low
                epsilon = epsilon * correct_preds_clean
                epsilon = torch.clamp(epsilon, min=0)
                epsilon = epsilon * self.args.beta + epsilon_prev * (
                    1 - self.args.beta)
        # Updating memory
        self.epsilon_memory[indices] = epsilon
        return epsilon

    def train(self):

        losses = utils.AverageMeter()

        while self.epoch < self.args.nepochs:
            self.model.train()
            correct = 0
            total = 0
            start_time = time.time()

            for i, data in enumerate(self.train_loader):
                # An index-returning dataset wrapper is assumed here; the plain
                # datasets.CIFAR10 built in __init__ yields (image, label) only.
                input, target, indices = data
                target = target.cuda(non_blocking=True)
                input = input.cuda(non_blocking=True)

                if self.args.alg == 'adv_training':
                    epsilon_arr = self.epsilon_select(input, target, indices)
                    attack_lr_arr = torch.clamp(epsilon_arr /
                                                (0.5 * self.args.attack_steps),
                                                min=self.args.attack_lr)
                    input = self.attacker.attack(input,
                                                 target,
                                                 self.model,
                                                 self.args.attack_steps,
                                                 attack_lr_arr,
                                                 epsilon_arr,
                                                 random_init=True,
                                                 target=None)
                    self.model.zero_grad()

                # compute output
                self.optimizer.zero_grad()
                logits = self.model(input)
                loss = F.cross_entropy(logits, target)

                loss.backward()
                self.optimizer.step()

                _, pred = torch.max(logits, dim=1)
                correct += (pred == target).sum()
                total += target.size(0)

                # measure accuracy and record loss
                losses.update(loss.data.item(), input.size(0))

            self.epoch += 1
            self.lr_scheduler.step()
            end_time = time.time()
            batch_time = end_time - start_time

            acc = (float(correct) / total) * 100
            message = 'Epoch {}, Time {}, Loss: {}, Accuracy: {}'.format(
                self.epoch, batch_time, loss.item(), acc)
            self._log(message)
            self._save_checkpoint()

            if self.epoch == self.args.warmup:
                self._save_checkpoint(model_name='end_of_warmup.pth')

            # Evaluation
            nat_acc = self.eval()
            adv_acc = self.eval_adversarial()
            self._log('Natural accuracy: {}'.format(nat_acc))
            self._log('Adv accuracy: {}'.format(adv_acc))

    def eval(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target, _ = data  # index-returning dataset assumed (see train loop)
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy

    def eval_adversarial(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target, _ = data  # index-returning dataset assumed (see train loop)
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            input = self.attacker_test.attack(input,
                                              target,
                                              self.model,
                                              self.args.attack_steps,
                                              self.args.attack_lr,
                                              random_init=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy
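
The three-way update inside epsilon_select, restated: each still-correctly-classified sample is re-attacked at its current radius and at a radius gamma larger; if the model survives even the larger one the radius grows, if it survives only the current one it stays, and if it is already fooled at the current radius it shrinks by gamma. The result is zeroed for samples the model misclassifies cleanly and smoothed with an EMA. A scalar restatement of that rule (equivalent to the tensor code above, not the original):

def epsilon_update(eps_prev, robust_high, robust_cur, correct_clean,
                   gamma, beta):
    """Per-sample radius schedule from epsilon_select, written for one sample."""
    if robust_high:                  # still correct at eps_prev + gamma
        eps_new = eps_prev + gamma
    elif robust_cur:                 # correct at eps_prev only
        eps_new = eps_prev
    else:                            # already fooled at eps_prev
        eps_new = eps_prev - gamma
    eps_new = max(eps_new, 0.0) * correct_clean    # zero out clean mistakes
    return beta * eps_new + (1 - beta) * eps_prev  # EMA smoothing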
Example #7
class Trainer:
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        transform_train = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomCrop(32),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ])

        transform_test = T.Compose([
            T.ToTensor()
        ])

        kwargs = {'num_workers': 4, 'pin_memory': True}

        self.train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root, train=True, download=True,
                             transform=transform_train),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root, train=False, transform=transform_test),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=32, width=10, num_classes=10)
        self.model = torch.nn.DataParallel(self.model).cuda()
        self.optimizer = optim.SGD(self.model.parameters(), args.lr,
                                   momentum=0.9, weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[60, 120, 160], gamma=0.2)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))


        self.save_path = args.save_path
        self.epoch = 0

        # resume from checkpoint
        ckpt_path = osp.join(args.save_path, 'checkpoint.pth')
        if osp.exists(ckpt_path):
            self._load_from_checkpoint(ckpt_path)
        elif args.restore:
            self._load_from_checkpoint(args.restore)

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)

    def _log(self, message):
        print(message)
        f = open(osp.join(self.save_path, 'log.txt'), 'a+')
        f.write(message + '\n')
        f.close()

    def _load_from_checkpoint(self, ckpt_path):
        print('Loading model from {} ...'.format(ckpt_path))
        model_data = torch.load(ckpt_path)
        self.model.load_state_dict(model_data['model'])
        self.optimizer.load_state_dict(model_data['optimizer'])
        self.lr_scheduler.load_state_dict(model_data['lr_scheduler'])
        self.epoch = model_data['epoch'] + 1
        print('Model loaded successfully')

    def _save_checkpoint(self):
        self.model.eval()
        model_data = dict()
        model_data['model'] = self.model.state_dict()
        model_data['optimizer'] = self.optimizer.state_dict()
        model_data['lr_scheduler'] = self.lr_scheduler.state_dict()
        model_data['epoch'] = self.epoch
        torch.save(model_data, osp.join(self.save_path, 'checkpoint.pth'))

    def train(self):

        losses = utils.AverageMeter()

        while self.epoch < self.args.nepochs:
            self.model.train()
            correct = 0
            total = 0
            start_time = time.time()

            for i, data in enumerate(self.train_loader):
                input, target = data
                target = target.cuda(non_blocking=True)
                input = input.cuda(non_blocking=True)

                if self.args.alg == 'adv_training':
                    input = self.attacker.attack(input, target, self.model, self.args.attack_steps, self.args.attack_lr,
                                                 random_init=True, target=None)

                # compute output
                self.optimizer.zero_grad()
                logits = self.model(input)
                loss = F.cross_entropy(logits, target)

                loss.backward()
                self.optimizer.step()

                _, pred = torch.max(logits, dim=1)
                correct += (pred == target).sum()
                total += target.size(0)

                # measure accuracy and record loss
                losses.update(loss.data.item(), input.size(0))

            self.epoch += 1
            self.lr_scheduler.step()
            end_time = time.time()
            batch_time = end_time - start_time

            acc = (float(correct) / total) * 100
            message = 'Epoch {}, Time {}, Loss: {}, Accuracy: {}'.format(self.epoch, batch_time, loss.item(), acc)
            self._log(message)
            self._save_checkpoint()

            # Evaluation
            nat_acc = self.eval()
            adv_acc = self.eval_adversarial()
            self._log('Natural accuracy: {}'.format(nat_acc))
            self._log('Adv accuracy: {}'.format(adv_acc))

    def eval(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy

    def eval_adversarial(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            input = self.attacker.attack(input, target, self.model, self.args.attack_steps, self.args.attack_lr,
                                         random_init=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy
Example #8
class Visualizer:
    def __init__(self, args):

        self.args = args

        transformer = T.Compose([T.ToTensor()])
        kwargs = {'num_workers': 16, 'pin_memory': True}

        train_set = datasets.CIFAR10(args.data_root,
                                     train=True,
                                     transform=transformer,
                                     download=True)
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        self.train_samples_np = train_set.data.astype(np.float32)
        self.train_samples_np = self.train_samples_np.transpose(0, 3, 1, 2)
        self.train_samples_np = np.reshape(
            self.train_samples_np, (self.train_samples_np.shape[0], -1))

        self.train_samples_np = self.train_samples_np / 255.0
        self.labels_np = np.array(train_set.targets)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=34, width=1, num_classes=10).cuda()
        self.model = models.convert_splitbn_model(self.model).cuda()

        # Loading model
        assert self.args.restore is not None

        model_data = torch.load(self.args.restore)
        print(self.args.restore)
        self.model.load_state_dict(model_data['model'])
        self.model.eval()

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)
        self.save_path = self.args.save_path

    def _log(self, message):
        print(message)
        f = open(osp.join(self.save_path, 'log.txt'), 'a+')
        f.write(message + '\n')
        f.close()

    def get_nn(self, inputs, labels):

        samples_nn = []
        labels_nn = []

        for i in range(inputs.size(0)):
            img_query = inputs[i].cpu().numpy()
            img_query = img_query.astype(np.float32)
            img_query = np.reshape(img_query, (1, -1))

            label_query = labels[i].cpu().numpy()

            valid_indices = (self.labels_np != label_query).nonzero()[0]
            valid_samples = self.train_samples_np[valid_indices]

            dist_mat = scipy.spatial.distance.cdist(valid_samples, img_query)
            dist_mat = np.reshape(dist_mat, (-1))
            min_ind = np.argmin(dist_mat)
            sample_ind = valid_indices[min_ind]
            samples_nn.append(self.train_samples_np[sample_ind])
            labels_nn.append(self.labels_np[sample_ind])

        samples_nn = np.array(samples_nn)
        samples_nn = np.reshape(samples_nn, (samples_nn.shape[0], 3, 32, 32))
        labels_nn = np.array(labels_nn)

        return samples_nn, labels_nn

    def visualize(self):
        """
        Module to compute and visualize adversarial perturbations
        """

        num_vis = 5
        input_list = []
        input_adv_list = []
        input_nn_list = []

        count = 0
        for (i, batch) in enumerate(self.train_loader):
            imgs, labels = batch
            input = imgs.cuda()
            labels = labels.cuda()

            # Random targets
            # target = labels + torch.randint(low=1, high=10, size=labels.size()).cuda()
            # target = torch.fmod(target, 10)

            # Nearest neighbor targets
            inputs_nn, labels_nn = self.get_nn(imgs, labels)
            labels_nn = torch.from_numpy(labels_nn).long().cuda()
            input_nn_list.append(inputs_nn)

            input_adv = self.attacker.attack(input,
                                             labels,
                                             self.model,
                                             self.args.attack_steps,
                                             self.args.attack_lr,
                                             random_init=True,
                                             target=labels_nn)

            input_np = input.cpu().numpy()
            input_adv_np = input_adv.detach().cpu().numpy()

            input_list.append(input_np)
            input_adv_list.append(input_adv_np)

            count += input_np.shape[0]

            if count > num_vis:
                break

        input_list = np.vstack(input_list)
        input_adv_list = np.vstack(input_adv_list)
        input_nn_list = np.vstack(input_nn_list)

        input_list = input_list.transpose(0, 2, 3, 1)
        input_list = (input_list * 255.0).astype(np.uint8)
        input_adv_list = input_adv_list.transpose(0, 2, 3, 1)
        input_adv_list = (input_adv_list * 255.0).astype(np.uint8)
        input_nn_list = input_nn_list.transpose(0, 2, 3, 1)
        input_nn_list = (input_nn_list * 255.0).astype(np.uint8)

        ## Generating visualization
        fig, axs = plt.subplots(3, num_vis, figsize=(15, 7))

        for i in range(num_vis):
            axs[0][i].imshow(input_list[i])
            axs[0][i].axis('off')

        for i in range(num_vis):
            axs[1][i].imshow(input_adv_list[i])
            axs[1][i].axis('off')

        for i in range(num_vis):
            axs[2][i].imshow(input_nn_list[i])
            axs[2][i].axis('off')

        plt.savefig('{}/adv.png'.format(self.args.save_path))
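
get_nn runs one cdist call per query image inside a Python loop. Since train_samples_np is already a flat (N, 3072) matrix, the same different-label nearest-neighbor search can be done for the whole batch at once; a vectorized equivalent (a sketch assuming the queries are flattened the same way as in __init__):

import numpy as np
import scipy.spatial.distance

def get_nn_batch(queries, labels, samples, sample_labels):
    """Nearest training image with a different label, for a whole batch at once."""
    dist = scipy.spatial.distance.cdist(queries, samples)      # (B, N)
    dist[labels[:, None] == sample_labels[None, :]] = np.inf   # mask same-label columns
    idx = dist.argmin(axis=1)
    return samples[idx], sample_labels[idx]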
Example #9
    def eval_worker(self, attack):
        correct = 0
        total = 0
        err = 0
        print('Evaluating {}'.format(attack))

        attack_eps = 8.0 / 255
        if attack == 'PGD':
            attacker = PGDAttacker(attack_eps)
        else:
            attacker = Attackers(self.model, 10, attack, attack_eps,
                                 self.args.batch_size)

        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            if attack == 'PGD':
                attack_steps = 10
                attack_lr = 2.0 / 255
                pert = attacker.attack(input,
                                       target,
                                       self.model,
                                       attack_steps,
                                       attack_lr,
                                       random_init=True)
                with torch.no_grad():
                    output = self.model(pert)

                _, pred = torch.max(output, dim=1)
                correct += (pred == target).sum()
                total += target.size(0)
            elif attack == 'CW':
                pert = attacker.attack(input, target)
                with torch.no_grad():
                    output = self.model(pert)

                _, pred = torch.max(output, dim=1)
                correct += (pred == target).sum()
                total += target.size(0)
            else:
                for j in range(input.size(0)):
                    inp = input[j].unsqueeze(0)
                    tar = target[j].unsqueeze(0)
                    att = attacker.attack(inp, tar)
                    if att is not None:
                        with torch.no_grad():
                            output = self.model(att)

                        _, pred = torch.max(output, dim=1)
                        if pred.item() != tar.item():
                            err += 1
                    total += 1
        print('Evaluated on {} samples'.format(total))

        if attack in ('PGD', 'CW'):
            accuracy = (float(correct) / total) * 100
        else:
            accuracy = (1.0 - float(err) / total) * 100

        return accuracy
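
eval_worker dispatches on the attack name: 'PGD' and 'CW' run batched and count correct predictions, while any other name falls back to the Attackers helper (not shown), evaluates one image at a time, and counts errors instead; both paths are converted to a robust-accuracy percentage at the end. A hypothetical usage sketch (the evaluator object and any attack names beyond 'PGD'/'CW' depend on code not shown here):

def evaluate_attacks(evaluator, names=('PGD', 'CW')):
    """Run eval_worker for several attacks and collect robust accuracies."""
    return {name: evaluator.eval_worker(name) for name in names}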
Example #10
class Evaluator:
    def __init__(self, args):

        self.args = args

        transformer = T.Compose([
            T.ToTensor()
        ])
        kwargs = {'num_workers': 4, 'pin_memory': True}

        self.val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_root, train=False, transform=transformer),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        # Create model, optimizer and scheduler
        self.model = models.WRN(depth=32, width=10, num_classes=10)
        self.model = torch.nn.DataParallel(self.model).cuda()

        # Loading model
        assert self.args.restore is not None

        model_data = torch.load(self.args.restore)
        self.model.load_state_dict(model_data['model'])
        self.model.eval()

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)
        self.save_path = self.args.save_path

    def _log(self, message):
        print(message)
        f = open(osp.join(self.save_path, 'log.txt'), 'a+')
        f.write(message + '\n')
        f.close()

    def eval(self):
        if self.args.attack_eps == 0 or self.args.attack_steps == 0:
            acc = self.eval_worker(adv_flag=False)
        else:
            acc = self.eval_worker(adv_flag=True)
        message = 'PGD-{}; Acc:{}'.format(self.args.attack_steps, acc)
        self._log(message)
        return acc

    def eval_worker(self, adv_flag=True):
        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            if adv_flag:
                input = self.attacker.attack(input, target, self.model, self.args.attack_steps, self.args.attack_lr,
                                             random_init=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)

        accuracy = (float(correct) / total) * 100
        return accuracy
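
All of these classes consume a single args namespace, and the field names can be read off the attribute accesses. A minimal argparse setup compatible with the Evaluator above (flag defaults are illustrative assumptions; the eps/steps/lr values echo the PGD settings in Example #9):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='./data')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--restore', required=True, help='checkpoint to evaluate')
parser.add_argument('--save_path', default='./eval_logs')
parser.add_argument('--attack_eps', type=float, default=8.0 / 255)
parser.add_argument('--attack_steps', type=int, default=10)
parser.add_argument('--attack_lr', type=float, default=2.0 / 255)
args = parser.parse_args()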
Example #11
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        kwargs = {'num_workers': 4, 'pin_memory': True}
        if args.dataset == 'MNIST':
            # setup data loader
            self.train_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data',
                               train=True,
                               download=True,
                               transform=T.ToTensor()),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

            self.val_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data', train=False, transform=T.ToTensor()),
                batch_size=args.batch_size,
                shuffle=False,
                **kwargs)

            # initialize model
            torch.manual_seed(args.seed)
            self.model = models.SmallCNN()

        elif args.dataset == 'CIFAR10':
            transform_train = T.Compose([
                T.Pad(4, padding_mode='reflect'),
                T.RandomCrop(32),
                T.RandomHorizontalFlip(),
                T.ToTensor()
            ])
            transform_test = T.Compose([T.ToTensor()])

            self.train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10(args.data_root,
                                 train=True,
                                 download=True,
                                 transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            self.val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10(args.data_root,
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            # initialize model
            torch.manual_seed(args.seed)
            self.model = models.ResNet18()

        self.model = torch.nn.DataParallel(self.model).cuda()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   args.lr,
                                   momentum=0.9,
                                   weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[60, 120, 160], gamma=0.2)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.save_path = args.save_path
        self.epoch = 0

        # resume from checkpoint

        ckpt_path = osp.join(self.save_path, 'checkpoint.pth')
        if osp.exists(ckpt_path):
            self._load_from_checkpoint(ckpt_path)
        elif args.restore:
            self._load_from_checkpoint(args.restore)

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)
Example #12
class Trainer:
    def __init__(self, args):

        self.args = args

        # Creating data loaders
        kwargs = {'num_workers': 4, 'pin_memory': True}
        if args.dataset == 'MNIST':
            # setup data loader
            self.train_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data',
                               train=True,
                               download=True,
                               transform=T.ToTensor()),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

            self.val_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data', train=False, transform=T.ToTensor()),
                batch_size=args.batch_size,
                shuffle=False,
                **kwargs)

            # initialize model
            torch.manual_seed(args.seed)
            self.model = models.SmallCNN()

        elif args.dataset == 'CIFAR10':
            transform_train = T.Compose([
                T.Pad(4, padding_mode='reflect'),
                T.RandomCrop(32),
                T.RandomHorizontalFlip(),
                T.ToTensor()
            ])
            transform_test = T.Compose([T.ToTensor()])

            self.train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10(args.data_root,
                                 train=True,
                                 download=True,
                                 transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            self.val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10(args.data_root,
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            # initialize model
            torch.manual_seed(args.seed)
            self.model = models.ResNet18()

        self.model = torch.nn.DataParallel(self.model).cuda()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   args.lr,
                                   momentum=0.9,
                                   weight_decay=args.weight_decay)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[60, 120, 160], gamma=0.2)

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        self.save_path = args.save_path
        self.epoch = 0

        # resume from checkpoint

        ckpt_path = osp.join(self.save_path, 'checkpoint.pth')
        if osp.exists(ckpt_path):
            self._load_from_checkpoint(ckpt_path)
        elif args.restore:
            self._load_from_checkpoint(args.restore)

        cudnn.benchmark = True
        self.attacker = PGDAttacker(args.attack_eps)

    def _log(self, message):
        print(message)
        f = open(osp.join(self.save_path, 'log.txt'), 'a+')
        f.write(message + '\n')
        f.close()

    def _load_from_checkpoint(self, ckpt_path):
        print('Loading model from {} ...'.format(ckpt_path))
        model_data = torch.load(ckpt_path)
        self.model.load_state_dict(model_data['model'])
        self.optimizer.load_state_dict(model_data['optimizer'])
        self.lr_scheduler.load_state_dict(model_data['lr_scheduler'])
        self.epoch = model_data['epoch'] + 1
        print('Model loaded successfully')

    def _save_checkpoint(self):
        self.model.eval()
        model_data = dict()
        model_data['model'] = self.model.state_dict()
        model_data['optimizer'] = self.optimizer.state_dict()
        model_data['lr_scheduler'] = self.lr_scheduler.state_dict()
        model_data['epoch'] = self.epoch
        torch.save(model_data, osp.join(self.save_path, 'checkpoint.pth'))
        print("Saving to: ", self.save_path)

    def train(self):

        losses = utils.AverageMeter()

        while self.epoch < self.args.nepochs:
            self.model.train()
            correct = 0
            total = 0
            start_time = time.time()

            for i, data in enumerate(self.train_loader):
                input, target = data
                target = target.cuda(non_blocking=True)
                input = input.cuda(non_blocking=True)

                if self.args.alg == 'adv_training':
                    input = self.attacker.attack(input,
                                                 target,
                                                 self.model,
                                                 self.args.attack_steps,
                                                 self.args.attack_lr,
                                                 random_init=True,
                                                 target=None,
                                                 mode=self.args.mode,
                                                 alpha=self.args.alpha)

                # compute output
                self.optimizer.zero_grad()
                logits = self.model(input)

                if self.args.mode == 'margin':
                    proba = F.softmax(logits, dim=-1)
                    proba_y = proba[list(range(input.size(0))), target]
                    proba_tops, ids_tops = torch.topk(proba, 2)

                    proba_top1 = proba_tops[:, 0]
                    proba_top2 = proba_tops[:, 1]

                    ids_top1 = ids_tops[:, 0]
                    ids_top2 = ids_tops[:, 1]

                    top1_t = (ids_top1.long() != target).float() * proba_top1
                    top2_t = (ids_top1.long() == target).float() * proba_top2

                    proba_t = top1_t + top2_t
                    margin = proba_y - proba_t
                    exp_coe = torch.exp(-margin * self.args.alpha)

                    loss = F.cross_entropy(logits, target, reduction='none')
                    loss = torch.mean(loss * exp_coe)

                elif self.args.mode == 'baseline':
                    loss = F.cross_entropy(logits, target)

                loss.backward()
                self.optimizer.step()

                _, pred = torch.max(logits, dim=1)
                correct += (pred == target).sum()
                total += target.size(0)

                # measure accuracy and record loss
                losses.update(loss.data.item(), input.size(0))
                # break

            self.epoch += 1
            self.lr_scheduler.step()
            end_time = time.time()
            batch_time = end_time - start_time

            acc = (float(correct) / total) * 100
            message = 'Epoch {}, Time {}, Loss: {}, Accuracy: {}'.format(
                self.epoch, batch_time, loss.item(), acc)
            self._log(message)
            self._save_checkpoint()

            # Evaluation
            nat_acc = self.eval()
            adv_acc = self.eval_adversarial()
            self._log('Natural accuracy: {}'.format(nat_acc))
            self._log('Adv accuracy: {}'.format(adv_acc))
            # break

    def eval(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)
            # break

        accuracy = (float(correct) / total) * 100
        return accuracy

    def eval_adversarial(self):
        self.model.eval()

        correct = 0
        total = 0
        for i, data in enumerate(self.val_loader):
            input, target = data
            target = target.cuda(non_blocking=True)
            input = input.cuda(non_blocking=True)

            input = self.attacker.attack(input,
                                         target,
                                         self.model,
                                         self.args.attack_steps,
                                         self.args.attack_lr,
                                         random_init=True)

            # compute output
            with torch.no_grad():
                output = self.model(input)

            _, pred = torch.max(output, dim=1)
            correct += (pred == target).sum()
            total += target.size(0)
            # break

        accuracy = (float(correct) / total) * 100
        return accuracy
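
The 'margin' mode above reweights each sample's cross-entropy by exp(-alpha * (p_y - p_t)), where p_y is the softmax probability of the true class and p_t the largest probability among the other classes (recovered via the top-2 trick), so small-margin examples contribute more to the loss. A compact, equivalent restatement (a sketch, not the original code; like the original, the weight is not detached, so gradients flow through it too):

import torch
import torch.nn.functional as F

def margin_weighted_ce(logits, target, alpha):
    """Cross-entropy scaled per sample by exp(-alpha * margin)."""
    proba = F.softmax(logits, dim=-1)
    p_y = proba.gather(1, target.unsqueeze(1)).squeeze(1)        # true-class probability
    others = proba.scatter(1, target.unsqueeze(1), float('-inf'))
    p_t = others.max(dim=1).values                               # best non-true class
    weight = torch.exp(-alpha * (p_y - p_t))
    loss = F.cross_entropy(logits, target, reduction='none')
    return (loss * weight).mean()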