# Imports assumed by the snippets in this listing. The helpers utils, ot,
# adv_interp, one_hot_tensor, softCrossEntropy, soft_xent_loss and the
# module-level globals (args, device, criterion, optimizer, the data loaders,
# config_adv_interp) come from the surrounding repository; minimal sketches of
# a few of them appear between the examples below.
import copy
import os
import pickle
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


def test(epoch, net):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    xce = 0.

    iterator = tqdm(testloader, ncols=0, leave=False)
    # x_adv = torch.load('x_adv.pt')['x_adv']
    # print(x_adv.size())
    # i = -1
    for batch_idx, (inputs, targets) in enumerate(iterator):
        # i += 1
        start_time = time.time()
        inputs, targets = inputs.to(device), targets.to(device)
        pert_inputs = inputs.detach()
        # pert_inputs, targets = x_adv[i*args.batch_size_test:np.minimum((i+1)*args.batch_size_test, 10000)].to(device), targets.to(device)

        outputs, _, _, pert_inputs, pert_i = net(pert_inputs,
                                                 targets,
                                                 batch_idx=batch_idx)

        xce_batch = torch.sum(-utils.one_hot_tensor(targets, 10, device) *
                              F.log_softmax(outputs, dim=1)).item()

        loss = criterion(outputs, targets)
        test_loss += loss.item()

        duration = time.time() - start_time

        _, predicted = outputs.max(1)
        batch_size = targets.size(0)
        total += batch_size
        correct_num = predicted.eq(targets).sum().item()
        correct += correct_num
        iterator.set_description(str(correct_num / batch_size))

        xce += xce_batch

        if batch_idx % args.log_step == 0:
            print(
                "step %d, duration %.2f, test acc %.2f, avg-acc %.2f, avg-loss %.4f"
                % (batch_idx, duration, 100. * correct_num / batch_size,
                   100. * correct / total, test_loss / (batch_idx + 1)))

    acc = 100. * correct / total
    print('Val acc:', acc)
    xce = xce / total
    print('xce : ', xce)
    return acc
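
The helper one_hot_tensor (also reached as utils.one_hot_tensor above) is not part of this listing. A minimal sketch of what such a helper could look like, assuming it builds a float one-hot matrix on the given device:

def one_hot_tensor(targets, num_classes, device):
    # targets: LongTensor of shape (batch,); returns a (batch, num_classes) float tensor
    y_onehot = torch.zeros(targets.size(0), num_classes, device=device)
    y_onehot.scatter_(1, targets.view(-1, 1), 1.0)
    return y_onehot
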
Example #2
def attack_inversion(model, inputs, targets, config):
    step_size = config['step_size']
    epsilon = config['epsilon']
    num_steps = config['num_steps']
    ls_factor = config['ls_factor']
    model.eval()
    inv_idx = torch.arange(inputs.size(0) - 1, -1, -1).long()
    x = inputs.detach()
    x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)

    logits_pred_nat, fea_nat = model(inputs[inv_idx])
    fea_nat = fea_nat.detach()
    num_classes = logits_pred_nat.size(-1)

    for i in range(num_steps):
        x.requires_grad_()
        # clear any gradient left over from the previous attack step
        if x.grad is not None:
            x.grad.data.fill_(0)
        logits_pred, fea = model(x)
        #inver_loss = ot.cost_matrix_cos(fea, fea_nat)
        inver_loss = ot.pair_cos_dist(fea, fea_nat)
        #inver_loss = torch.div(torch.norm(fea - fea_nat, dim=1), torch.norm(fea_nat, dim=1))
        adv_loss = inver_loss.mean()
        adv_loss.backward()
        x_adv = x.data - step_size * torch.sign(x.grad.data)
        x_adv = torch.min(torch.max(x_adv, inputs-epsilon), inputs+epsilon)
        x_adv = torch.clamp(x_adv, -1.0, 1.0)
        x = x_adv.detach()

    targets_one_hot = one_hot_tensor(targets, num_classes, device)
    # if adapt label smooth
    targets_one_hot_inv = targets_one_hot[inv_idx]
    soft_targets = (1 - targets_one_hot_inv) / (num_classes - 1)
    soft_targets = (1 - ls_factor) * targets_one_hot + ls_factor * soft_targets
    # if not adapt label smooth
    # soft_targets = utils.label_smoothing(targets_one_hot, targets_one_hot.size(1), ls_factor)
    return x, soft_targets
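
A minimal usage sketch for attack_inversion. The config keys match the ones read at the top of the function; the numeric values and the call to soft_xent_loss are illustrative assumptions, not settings taken from the repository:

# hypothetical attack configuration (key names from attack_inversion, values are placeholders)
config_inversion = {
    'step_size': 2.0 / 255,
    'epsilon': 8.0 / 255,
    'num_steps': 10,
    'ls_factor': 0.1,
}

# inside a training loop, with model, inputs, targets already on device:
x_adv, soft_targets = attack_inversion(model, inputs, targets, config_inversion)
logits, _ = model(x_adv)                     # the model here returns (logits, features)
loss = soft_xent_loss(logits, soft_targets)  # soft-label cross-entropy, as used elsewhere in this listing
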
Example #3
    def forward(self,
                inputs,
                targets,
                attack=True,
                targeted_label=-1,
                batch_idx=0):

        if not attack:
            outputs, _ = self.basic_net(inputs)
            return outputs, None
        if self.box_type == 'white':
            aux_net = pickle.loads(pickle.dumps(self.basic_net))
        elif self.box_type == 'black':
            assert self.attack_net is not None, "should provide an additional net in black-box case"
            aux_net = pickle.loads(pickle.dumps(self.attack_net))

        aux_net.eval()
        batch_size = inputs.size(0)
        m = batch_size
        n = batch_size

        # logits = aux_net(inputs)[0]
        # num_classes = logits.size(1)

        # outputs = aux_net(inputs)[0]
        # targets_prob = F.softmax(outputs.float(), dim=1)
        # y_tensor_adv = targets
        # step_sign = 1.0

        x = inputs.detach()

        # x_org = x.detach()
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)

        if self.train_flag:
            self.basic_net.train()
        else:
            self.basic_net.eval()

        logits_pred_nat, fea_nat = aux_net(inputs)

        num_classes = logits_pred_nat.size(1)
        y_gt = one_hot_tensor(targets, num_classes, device)

        loss_ce = softCrossEntropy()

        iter_num = self.num_steps

        for i in range(iter_num):
            x.requires_grad_()
            # clear any gradient left over from the previous attack step
            if x.grad is not None:
                x.grad.data.fill_(0)

            logits_pred, fea = aux_net(x)

            ot_loss = ot.sinkhorn_loss_joint_IPOT(1, 0.00, logits_pred_nat,
                                                  logits_pred, None, None,
                                                  0.01, m, n)

            aux_net.zero_grad()
            adv_loss = ot_loss
            adv_loss.backward(retain_graph=True)
            x_adv = x.data + self.step_size * torch.sign(x.grad.data)
            x_adv = torch.min(torch.max(x_adv, inputs - self.epsilon),
                              inputs + self.epsilon)
            x_adv = torch.clamp(x_adv, -1.0, 1.0)
            x = x_adv.detach()

            logits_pred, fea = self.basic_net(x)
            self.basic_net.zero_grad()

            y_sm = utils.label_smoothing(y_gt, y_gt.size(1), self.ls_factor)

            adv_loss = loss_ce(logits_pred, y_sm.detach())

        return logits_pred, adv_loss
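
forward uses softCrossEntropy and utils.label_smoothing, neither of which is shown in this listing. A minimal sketch of both, assuming the conventional definitions (cross-entropy against soft targets, and uniform smoothing over all classes):

class softCrossEntropy(nn.Module):
    # cross-entropy against soft target distributions: mean over the batch of -sum_c q_c * log p_c
    def forward(self, logits, soft_targets):
        log_probs = F.log_softmax(logits, dim=1)
        return torch.mean(torch.sum(-soft_targets * log_probs, dim=1))

def label_smoothing(y_onehot, num_classes, factor):
    # spread a fraction `factor` of the probability mass uniformly over all classes
    return y_onehot * (1.0 - factor) + factor / num_classes
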
def train_fun(epoch, net):
    print('\nEpoch: %d' % epoch)
    net.train()

    train_loss = 0
    correct = 0
    total = 0

    # update learning rate
    if epoch < args.decay_epoch1:
        lr = args.lr
    elif epoch < args.decay_epoch2:
        lr = args.lr * args.decay_rate
    else:
        lr = args.lr * args.decay_rate * args.decay_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    def get_acc(outputs, targets):
        _, predicted = outputs.max(1)
        total = targets.size(0)
        correct = predicted.eq(targets).sum().item()
        acc = 1.0 * correct / total
        return acc

    iterator = tqdm(trainloader, ncols=0, leave=False)
    # iterator = trainloader

    for batch_idx, (inputs, targets) in enumerate(iterator):
        # for tuples in enumerate(iterator):
        start_time = time.time()
        if args.dataset == 'cifar_aug':

            inputs_aug, targets_aug = next(iter(trainloader_aug))
            indices = np.random.permutation(targets_aug.size()[0])
            inputs_aug = inputs_aug[indices]
            inputs_orig, targets_orig = inputs.detach(), targets.detach()

            inputs[:args.batch_size_train //
                   5] = inputs_aug[:args.batch_size_train // 5]
            # targets = np.eye(args.batch_size_train)[targets]
            targets = one_hot_tensor(targets, 10, device)
            targets[:args.batch_size_train // 5, :] = 0.1

        inputs, targets = inputs.to(device), targets.to(device)

        adv_acc = 0

        optimizer.zero_grad()

        # forward feature_scatter
        if (args.adv_mode.lower() == 'feature_scatter'
                or args.adv_mode.lower() == 'lip_reg'
                or args.adv_mode.lower() == 'trades'):
            outputs, loss_fs, flag_out, _, diff_loss = net(
                inputs.detach(), targets)
            loss = loss_fs
            optimizer.zero_grad()
        elif args.adv_mode.lower() == 'madry':
            # forward madry
            outputs, _, _, pert_inputs, pert_i, y_train = net(inputs, targets)
            loss = soft_xent_loss(outputs, y_train)
            # loss = soft_xent_loss(outputs * 0.5, y_train)  # temperature scaling
            #loss = F.cross_entropy(outputs, targets)
            optimizer.zero_grad()

        elif args.adv_mode.lower() == 'vertex':
            # forward vertex
            outputs, _, _, _, _, y_vertex = net(inputs, targets)
            # outputs, _, _, _, _, y_vertex = net(inputs, targets, epoch = (epoch+1) / args.max_epoch)
            loss = soft_xent_loss(outputs, y_vertex)
            optimizer.zero_grad()
        elif args.adv_mode.lower() == 'vertex_pgd':
            # forward vertex
            outputs, _, _, _, _, y_vertex = net(inputs, targets)
            loss = soft_xent_loss(outputs, y_vertex)
            optimizer.zero_grad()
        elif args.adv_mode.lower() == 'natural':
            # forward natural (no attack)
            outputs, _, _, _, _ = net(inputs, targets)
            # loss = F.cross_entropy(basic_net(inputs.detach())[0], targets)
            loss = F.cross_entropy(outputs, targets)
            optimizer.zero_grad()
        elif args.adv_mode.lower() == 'linear':
            # forward linear
            outputs, _, _, x_train, _ = net(inputs, targets)
            # net(inputs, targets)
            # outputs = basic_net(inputs.detach())[0]
            # loss = F.cross_entropy(outputs, targets.detach())
            outputs, loss_fs, flag_out, _, diff_loss = net(
                inputs.detach(), targets)
            loss = loss_fs
            optimizer.zero_grad()
        else:
            raise ValueError('unknown adv_mode: %s' % args.adv_mode)

        loss.backward()

        optimizer.step()

        train_loss = loss.item()

        duration = time.time() - start_time

        if batch_idx % args.log_step == 0:
            if args.dataset == 'cifar_aug':
                inputs, targets = inputs_orig.to(device), targets_orig.to(
                    device)
            if adv_acc == 0:
                adv_acc = get_acc(outputs, targets)
            iterator.set_description(str(adv_acc))

            nat_outputs, _, _, _, _ = net(inputs, targets, attack=False)

            nat_acc = get_acc(nat_outputs, targets)

            print(
                "epoch %d, step %d, lr %.4f, duration %.2f, training nat acc %.2f, training adv acc %.2f, training adv loss %.4f"
                % (epoch, batch_idx, lr, duration, 100 * nat_acc,
                   100 * adv_acc, train_loss))

    if epoch % 10 == 0:
        print('Saving..')
        f_path = os.path.join(args.model_dir, ('checkpoint-%s' % epoch))
        state = {
            'net': net.state_dict(),
            # 'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        torch.save(state, f_path)

    if epoch >= 0:
        print('Saving latest @ epoch %s..' % (epoch))
        f_path = os.path.join(args.model_dir, 'latest')
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        torch.save(state, f_path)
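
train_fun saves a "latest" checkpoint whose state dict holds the keys net, epoch and optimizer. A sketch of how such a checkpoint could be restored to resume training; the keys and path follow the save code above, while the loading logic itself is an assumption:

ckpt_path = os.path.join(args.model_dir, 'latest')
if os.path.isfile(ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1  # resume from the following epoch
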
Example #5
def train_one_epoch(epoch, net):
    print('\n Training for Epoch: %d' % epoch)

    net.train()

    # learning rate schedule
    if epoch < args.decay_epoch1:
        lr = args.lr
    elif epoch < args.decay_epoch2:
        lr = args.lr * args.decay_rate
    else:
        lr = args.lr * args.decay_rate * args.decay_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    iterator = tqdm(trainloader, ncols=0, leave=False)
    for batch_idx, (inputs, targets) in enumerate(iterator):
        start_time = time.time()
        inputs, targets = inputs.to(device), targets.to(device)

        targets_onehot = one_hot_tensor(targets, args.num_classes, device)

        x_tilde, y_tilde = adv_interp(inputs, targets_onehot, net,
                                      args.num_classes,
                                      config_adv_interp['epsilon'],
                                      config_adv_interp['label_adv_delta'],
                                      config_adv_interp['v_min'],
                                      config_adv_interp['v_max'])

        outputs = net(x_tilde, mode='logits')
        loss = soft_xent_loss(outputs, y_tilde)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        train_loss = loss.detach().item()

        duration = time.time() - start_time
        if batch_idx % args.log_step == 0:

            adv_acc = utils.get_acc(outputs, targets)
            # natural
            net_cp = copy.deepcopy(net)
            nat_outputs = net_cp(inputs, mode='logits')
            nat_acc = utils.get_acc(nat_outputs, targets)
            print(
                "Epoch %d, Step %d, lr %.4f, Duration %.2f, Training nat acc %.2f, Training adv acc %.2f, Training adv loss %.4f"
                % (epoch, batch_idx, lr, duration, 100 * nat_acc,
                   100 * adv_acc, train_loss))

    if epoch % args.save_epochs == 0 or epoch >= args.max_epoch - 2:
        print('Saving..')
        f_path = os.path.join(args.model_dir, ('checkpoint-%s' % epoch))
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
            #'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir(args.model_dir):
            os.makedirs(args.model_dir)
        torch.save(state, f_path)

    if epoch >= 1:
        print('Saving latest model for epoch %s..' % (epoch))
        f_path = os.path.join(args.model_dir, 'latest')
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
            #'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        torch.save(state, f_path)
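
train_one_epoch expects a config_adv_interp dict with the keys read at the adv_interp call. A minimal driver sketch; only the key names come from the code above, the values and the loop itself are placeholder assumptions:

# hypothetical adversarial-interpolation settings (key names from the adv_interp call, values are placeholders)
config_adv_interp = {
    'epsilon': 8.0 / 255,
    'label_adv_delta': 0.5,
    'v_min': -1.0,
    'v_max': 1.0,
}

for epoch in range(args.max_epoch):
    train_one_epoch(epoch, net)
    test(epoch, net)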