def main():
    """
    Parse command-line arguments, build data readers and the (optional)
    pretrained word embedding, then launch training and evaluation.
    """
    args = parser.parse_args()
    global_config = configs.__dict__[args.config]()

    # Command line overrides the epoch count from the config file.
    if args.epoch_num is not None:
        global_config.epoch_num = args.epoch_num

    print("net_name: ", args.model_name)
    net = models.__dict__[args.model_name](global_config)

    # get word_dict
    word_dict = utils.getDict(data_type="quora_question_pairs")

    # get reader
    train_reader, dev_reader, test_reader = utils.prepare_data(
        "quora_question_pairs",
        word_dict=word_dict,
        batch_size=global_config.batch_size,
        buf_size=800000,
        duplicate_data=global_config.duplicate_data,
        use_pad=(not global_config.use_lod_tensor))

    # load pretrained_word_embedding (GloVe 840B.300d restricted to words
    # that actually occur in word_dict)
    if global_config.use_pretrained_word_embedding:
        word2vec = Glove840B_300D(
            filepath=os.path.join(DATA_DIR, "glove.840B.300d.txt"),
            keys=set(word_dict.keys()))
        pretrained_word_embedding = utils.get_pretrained_word_embedding(
            word2vec=word2vec, word2id=word_dict, config=global_config)
        print("pretrained_word_embedding to be load:",
              pretrained_word_embedding)
    else:
        pretrained_word_embedding = None

    # define optimizer
    optimizer = utils.getOptimizer(global_config)

    # When the config does not specify use_cuda, fall back to the
    # CUDA_VISIBLE_DEVICES environment variable.
    if not global_config.has_member('use_cuda'):
        if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ[
                'CUDA_VISIBLE_DEVICES'] != '':
            global_config.use_cuda = True
        else:
            global_config.use_cuda = False

    global_config.list_config()

    train_and_evaluate(
        train_reader,
        dev_reader,
        test_reader,
        net,
        optimizer,
        global_config,
        pretrained_word_embedding,
        use_cuda=global_config.use_cuda,
        parallel=False)
# Example #2
# 0
    def get_adv_by_convex_syn(self, embd, y, syn, syn_valid, text_like_syn,
                              attack_type_dict, text_for_vis, record_for_vis):
        """Craft an adversarial input as a convex combination of synonym embeddings.

        A weight tensor ``w`` over each token's synonym candidates is optimized
        by gradient ascent on the model loss; ``softmax(w)`` along the synonym
        axis gives the combination coefficients.  Invalid synonym slots are
        masked out by adding a large negative constant before the softmax.

        Args:
            embd: clean token embeddings, shape (batch, len, dim).
            y: gold labels (used by the 'ce' loss).
            syn: synonym candidate embeddings, shape (batch, len, syn_num, dim).
            syn_valid: 0/1 validity mask for the synonym slots.
            text_like_syn: synonym token ids; required when out_type == "text".
            attack_type_dict: keys num_steps, loss_func ('ce' | 'kl'),
                w_optm_lr, sparse_weight, out_type ('text' | 'comb_p').
            text_for_vis, record_for_vis: optional visualization hooks; only
                used when ``self.opt.vis_w_key_token`` is set.

        Returns:
            Detached adversarial text ids (out_type == "text") or the
            combination probabilities (out_type == "comb_p").
        """
        # record context
        self_training_context = self.training
        # set context
        if self.eval_adv_mode:
            self.eval()
        else:
            self.train()

        device = embd.device
        # get param of attacks

        num_steps = attack_type_dict['num_steps']
        loss_func = attack_type_dict['loss_func']
        w_optm_lr = attack_type_dict['w_optm_lr']
        sparse_weight = attack_type_dict['sparse_weight']
        out_type = attack_type_dict['out_type']

        batch_size, text_len, embd_dim = embd.shape
        batch_size, text_len, syn_num, embd_dim = syn.shape

        w = torch.empty(batch_size, text_len, syn_num,
                        1).to(device).to(embd.dtype)
        #ww = torch.zeros(batch_size, text_len, syn_num, 1).to(device).to(embd.dtype)
        #ww = ww+500*(syn_valid.reshape(batch_size, text_len, syn_num, 1)-1)
        nn.init.kaiming_normal_(w)
        w.requires_grad_()

        import utils
        params = [w]
        optimizer = utils.getOptimizer(params,
                                       name='adam',
                                       lr=w_optm_lr,
                                       weight_decay=2e-5)

        def get_comb_p(w, syn_valid):
            # Mask invalid slots with -500 before softmax so they get ~0 prob.
            ww = w * syn_valid.reshape(
                batch_size, text_len, syn_num, 1) + 500 * (
                    syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
            return F.softmax(ww, -2)

        def get_comb_ww(w, syn_valid):
            # Same masking as get_comb_p but returns the raw logits.
            ww = w * syn_valid.reshape(
                batch_size, text_len, syn_num, 1) + 500 * (
                    syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
            return ww

        def get_comb(p, syn):
            # Convex combination of synonym embeddings (sum over syn axis).
            return (p * syn.detach()).sum(-2)

        embd_ori = embd.detach()
        logit_ori = self.embd_to_logit(embd_ori)

        # Maximize the attack loss (note the negated losses below) over w.
        for _ in range(num_steps):
            optimizer.zero_grad()
            with torch.enable_grad():
                ww = get_comb_ww(w, syn_valid)
                #comb_p = get_comb_p(w, syn_valid)
                embd_adv = get_comb(F.softmax(ww, -2), syn)
                if loss_func == 'ce':
                    logit_adv = self.embd_to_logit(embd_adv)
                    loss = -F.cross_entropy(logit_adv, y, reduction='sum')
                elif loss_func == 'kl':
                    logit_adv = self.embd_to_logit(embd_adv)
                    criterion_kl = nn.KLDivLoss(reduction="sum")
                    loss = -criterion_kl(F.log_softmax(logit_adv, dim=1),
                                         F.softmax(logit_ori.detach(), dim=1))

                #print("ad loss:", loss.data.item())

                # Entropy regularizer pushes the combination toward a single
                # synonym (sparser comb_p).
                if sparse_weight != 0:
                    #loss_sparse = (comb_p*comb_p).mean()
                    loss_sparse = (-F.softmax(ww, -2) *
                                   F.log_softmax(ww, -2)).sum(-2).mean()
                    #loss -= sparse_weight*loss_sparse

                    loss = loss + sparse_weight * loss_sparse
                    #print(loss_sparse.data.item())

            #loss*=1000
            loss.backward()
            optimizer.step()

        #print((ww-w).max())

        comb_p = get_comb_p(w, syn_valid)

        # Optional visualization: dump comb_p for occurrences of a key token,
        # then exit once 300 records were collected.
        if self.opt.vis_w_key_token is not None:
            assert (text_for_vis is not None and record_for_vis is not None)
            vis_n, vis_l = text_for_vis.shape
            for i in range(vis_n):
                for j in range(vis_l):
                    if text_for_vis[i, j] == self.opt.vis_w_key_token:
                        record_for_vis["comb_p_list"].append(
                            comb_p[i, j].cpu().detach().numpy())
                        record_for_vis["embd_syn_list"].append(
                            syn[i, j].cpu().detach().numpy())
                        record_for_vis["syn_valid_list"].append(
                            syn_valid[i, j].cpu().detach().numpy())
                        record_for_vis["text_syn_list"].append(
                            text_like_syn[i, j].cpu().detach().numpy())

                        print("record for vis",
                              len(record_for_vis["comb_p_list"]))
                    if len(record_for_vis["comb_p_list"]) >= 300:
                        dir_name = self.opt.resume.split(self.opt.model)[0]
                        file_name = self.opt.dataset + "_vis_w_" + str(
                            self.opt.attack_sparse_weight) + "_" + str(
                                self.opt.vis_w_key_token) + ".pkl"
                        file_name = os.path.join(dir_name, file_name)
                        f = open(file_name, 'wb')
                        pickle.dump(record_for_vis, f)
                        f.close()
                        sys.exit()

        if out_type == "text":
            # need to be fix, has potential bugs. the trigger dependes on data.
            assert (text_like_syn is not None)  # n l synlen
            # Discretize: per token pick the synonym with max probability.
            comb_p = comb_p.reshape(batch_size * text_len, syn_num)
            ind = comb_p.max(-1)[1]  # shape batch_size* text_len
            out = (text_like_syn.reshape(
                batch_size * text_len,
                syn_num)[np.arange(batch_size * text_len),
                         ind]).reshape(batch_size, text_len)
        elif out_type == "comb_p":
            out = comb_p
        else:
            # Previously an unknown out_type crashed later with NameError.
            raise ValueError("unsupported out_type: %s" % out_type)

        # resume context
        if self_training_context:
            self.train()
        else:
            self.eval()

        return out.detach()
# Example #3
# 0
import gym
import world
import utils
from Buffer import ReplayBuffer
from models import DQN
from world import Print, ARGS
from wrapper import WrapIt
from procedure import train_DQN

# ------------------------------------------------
# Build the Atari environment and wrap it (WrapIt comes from the project's
# wrapper module; presumably applies the usual preprocessing -- TODO confirm).
env = gym.make('RiverraidNoFrameskip-v4')
env = WrapIt(env)
Print('ENV action', env.unwrapped.get_action_meanings())
Print('ENV observation', f"Image: {ARGS.imgDIM} X {ARGS.imgDIM} X {1}"
      )  # we assert to use gray image
# ------------------------------------------------
# Optimizer factory, exploration schedule and replay buffer for DQN.
Optimizer = utils.getOptimizer()
# Linear anneal over 1M steps down to 0.1; passed as `exploration` below.
schedule = utils.LinearSchedule(1000000, 0.1)

Game_buffer = ReplayBuffer(ARGS.buffersize, ARGS.framelen)

# Online network trains; target network stays in eval mode (assumed to be
# synced periodically inside train_DQN -- verify in procedure.train_DQN).
Q = utils.init_model(env, DQN).train().to(world.DEVICE)
Q_target = utils.init_model(env, DQN).eval().to(world.DEVICE)
# ------------------------------------------------
train_DQN(env,
          Q=Q,
          Q_target=Q_target,
          optimizer=Optimizer,
          replay_buffer=Game_buffer,
          exploration=schedule)
# Example #4
# 0
def train(opt, train_iter, test_iter, verbose=True):
    """Train ``models.setup(opt)`` on train_iter, tracking test precision.

    Evaluates on ``test_iter`` after each of ``opt.max_epoch`` passes, keeps
    only the best checkpoint on disk, and finally appends the best score to
    the tab-separated ``performance_log_file`` (a module-level path).

    Args:
        opt: namespace of hyper-parameters (optimizer, learning_rate, ...).
        train_iter, test_iter: batch iterators yielding .text / .label.
        verbose: when True, log per-batch loss and per-epoch precision.
    """
    global_start = time.time()
    logger = utils.getLogger()
    model = models.setup(opt)
    if torch.cuda.is_available():
        model.cuda()
    params = [param for param in model.parameters() if param.requires_grad
              ]  #filter(lambda p: p.requires_grad, model.parameters())

    # Human-readable summary of the simple-typed hyper-parameters; also used
    # as the row index in the performance log.
    model_info = ";".join([
        str(k) + ":" + str(v) for k, v in opt.__dict__.items()
        if type(v) in (str, int, float, list, bool)
    ])
    logger.info("# parameters:" + str(sum(param.numel() for param in params)))
    logger.info(model_info)

    model.train()
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))

    loss_fun = F.cross_entropy

    filename = None
    percisions = []
    for i in range(opt.max_epoch):
        for epoch, batch in enumerate(train_iter):
            optimizer.zero_grad()
            start = time.time()

            text = batch.text[0] if opt.from_torchtext else batch.text
            predicted = model(text)

            loss = loss_fun(predicted, batch.label)

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            if verbose:
                # loss.item() works for CPU and CUDA tensors alike; the old
                # CPU branch's loss.data.numpy()[0] raises IndexError on the
                # 0-dim arrays modern PyTorch returns for scalar losses.
                logger.info(
                    "%d iteration %d epoch with loss : %.5f in %.4f seconds"
                    % (i, epoch, loss.item(), time.time() - start))

        percision = utils.evaluation(model, test_iter, opt.from_torchtext)
        if verbose:
            logger.info("%d iteration with percision %.4f" % (i, percision))
        # Keep only the best checkpoint on disk; delete the previous one.
        if len(percisions) == 0 or percision > max(percisions):
            if filename:
                os.remove(filename)
            filename = model.save(metric=percision)
        percisions.append(percision)


#    while(utils.is_writeable(performance_log_file)):
    df = pd.read_csv(performance_log_file, index_col=0, sep="\t")
    df.loc[model_info, opt.dataset] = max(percisions)
    df.to_csv(performance_log_file, sep="\t")
    logger.info(model_info + " with time :" + str(time.time() - global_start) +
                " ->" + str(max(percisions)))
    print(model_info + " with time :" + str(time.time() - global_start) +
          " ->" + str(max(percisions)))
# Example #5
# 0
def main():
    """Train a Keras model from a TOML config given as ``sys.argv[1]``.

    Loads the config (default ./config.toml), reads train/test CSVs, compiles
    the model described by the config, and fits it with checkpointing,
    early stopping, CSV logging and LR reduction callbacks.
    """

    # loading config file ...
    cfgPath = sys.argv[1] if len(sys.argv) > 1 else './config.toml'
    cfg = loadConfig(cfgPath)

    try:
        # ... and unpacking variables
        dictget = lambda d, *k: [d[i] for i in k]

        dataStats = cfg['data_stats']
        modelParams = cfg['model_params']
        trainCSV, testCSV = dictget(cfg['database'], 'train', 'test')
        seqLength, stepSize = dictget(cfg['model_params'], 'seqLength',
                                      'stepSize')
        modelArch, modelDir, modelName = dictget(cfg['model_arch'],
                                                 'modelArch', 'modelDir',
                                                 'modelName')
        optimizer, lossFunc, metricFuncs = dictget(cfg['training_params'],
                                                   'optimizer', 'lossFunc',
                                                   'metricFuncs')
        # (stray trailing comma after `patience` removed)
        lr, epochs, batchSize, patience = dictget(cfg['training_params'],
                                                  'learningRate', 'epochs',
                                                  'batchSize', 'patience')
    except KeyError as err:
        print("\n\nERROR: not all parameters defined in config.toml : ", err)
        print("Exiting ... \n\n")
        sys.exit(1)

    print("Loading training data ...")
    xTrain, yTrain, stats = getData(trainCSV,
                                    seqLength=seqLength,
                                    stepSize=stepSize,
                                    stats=dataStats)
    print("Training Data Shape : ", xTrain.shape, "\n")

    print("Loading testing data ...")
    xTest, yTest, stats = getData(testCSV,
                                  seqLength=seqLength,
                                  stepSize=stepSize,
                                  stats=dataStats)
    print("Testing Data Shape : ", xTest.shape, "\n")

    yTrain = np.expand_dims(
        yTrain, -1)  # adding extra axis as model expects 2 axis in the output
    yTest = np.expand_dims(yTest, -1)

    print("Compiling Model")
    opt = getOptimizer(optimizer, lr)
    model = makeModel(modelArch, modelParams, verbose=True)
    model.compile(loss=lossFunc, optimizer=opt, metrics=metricFuncs)

    # setting up directories
    modelFolder = os.path.join(modelDir, modelName)
    weightsFolder = os.path.join(modelFolder, "weights")
    bestModelPath = os.path.join(weightsFolder, "best.hdf5")
    ensureDir(bestModelPath)

    # keep a copy of the config next to the model for reproducibility
    saveConfig(cfgPath, modelFolder)

    # callbacks
    monitorMetric = 'val_loss'
    # per-epoch snapshots (only added to the list when saveAllWeights is set)
    check1 = ModelCheckpoint(os.path.join(weightsFolder,
                                          modelName + "_{epoch:03d}.hdf5"),
                             monitor=monitorMetric,
                             mode='auto')
    check2 = ModelCheckpoint(bestModelPath,
                             monitor=monitorMetric,
                             save_best_only=True,
                             mode='auto')
    check3 = EarlyStopping(monitor=monitorMetric,
                           min_delta=0.01,
                           patience=patience,
                           verbose=0,
                           mode='auto')
    check4 = CSVLogger(os.path.join(modelFolder,
                                    modelName + '_trainingLog.csv'),
                       separator=',',
                       append=True)
    check5 = ReduceLROnPlateau(monitor=monitorMetric,
                               factor=0.1,
                               patience=patience // 3,
                               verbose=1,
                               mode='auto',
                               min_delta=0.001,
                               cooldown=0,
                               min_lr=1e-10)

    cb = [check2, check3, check4, check5]
    if cfg['training_params']['saveAllWeights']:
        cb.append(check1)

    print("Starting Training ...")
    model.fit(x=xTrain,
              y=yTrain,
              batch_size=batchSize,
              epochs=epochs,
              verbose=1,
              callbacks=cb,
              validation_data=(xTest, yTest),
              shuffle=True)
# Example #6
# 0
def train(opt, train_iter, dev_iter, test_iter, syn_data, verbose=True):
    """Adversarially train a text classifier against synonym substitutions.

    Runs 21 epochs combining a clean cross-entropy loss, an adversarial loss
    on perturbed inputs (perturbation chosen by ``opt.pert_set``) and a KL
    consistency term.  Keeps the checkpoint with the best dev adversarial
    accuracy, restores it, reports test accuracy, and finally runs genetic
    and PWWS attacks on the restored model.

    Args:
        opt: namespace of hyper-parameters and paths.
        train_iter / dev_iter / test_iter: batch iterators; train batches are
            tuples (text, label, anch, pos, neg, anch_valid, text_like_syn,
            text_like_syn_valid).
        syn_data: synonym lists indexed by token id (used for embedding prep).
        verbose: unused here, kept for interface compatibility.
    """
    global_start = time.time()
    #logger = utils.getLogger()
    model = models.setup(opt)

    if opt.resume is not None:
        model = set_params(model, opt.resume)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        model.cuda()
        #model=torch.nn.DataParallel(model)

    # set optimizer
    if opt.embd_freeze:
        model.embedding.weight.requires_grad = False
    else:
        model.embedding.weight.requires_grad = True
    params = [param for param in model.parameters() if param.requires_grad
              ]  #filter(lambda p: p.requires_grad, model.parameters())
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   weight_decay=opt.weight_decay,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))
    scheduler = WarmupMultiStepLR(optimizer, (40, 80), 0.1, 1.0 / 10.0, 2,
                                  'linear')

    from label_smooth import LabelSmoothSoftmaxCE
    if opt.label_smooth != 0:
        assert (opt.label_smooth <= 1 and opt.label_smooth > 0)
        loss_fun = LabelSmoothSoftmaxCE(lb_pos=1 - opt.label_smooth,
                                        lb_neg=opt.label_smooth)
    else:
        loss_fun = F.cross_entropy

    filename = None
    acc_adv_list = []
    start = time.time()
    kl_control = 0

    # initialize synonyms with the same embd
    from PWWS.word_level_process import word_process, get_tokenizer
    tokenizer = get_tokenizer(opt)

    if opt.embedding_prep == "same":
        # Disjoint-set style parent pointers: tokens connected through the
        # synonym graph end up sharing one representative embedding.
        father_dict = {}
        for index in range(1 + len(tokenizer.index_word)):
            father_dict[index] = index

        def get_father(x):
            # Find with path compression.
            if father_dict[x] == x:
                return x
            else:
                fa = get_father(father_dict[x])
                father_dict[x] = fa
                return fa

        for index in range(len(syn_data) - 1, 0, -1):
            syn_list = syn_data[index]
            for pos in syn_list:
                fa_pos = get_father(pos)
                fa_anch = get_father(index)
                if fa_pos == fa_anch:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                else:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                    father_dict[fa_pos] = index

        print("Same embedding for synonyms as embd prep.")
        set_different_embd = set()
        for key in father_dict:
            fa = get_father(key)
            set_different_embd.add(fa)
            with torch.no_grad():
                model.embedding.weight[key, :] = model.embedding.weight[fa, :]
        print(len(set_different_embd))

    elif opt.embedding_prep == "ge":
        # Initialize the embedding table from precomputed graph embeddings.
        print("Graph embedding as embd prep.")
        ge_file_path = opt.ge_file_path
        f = open(ge_file_path, 'rb')
        saved = pickle.load(f)
        ge_embeddings_dict = saved['walk_embeddings']
        #model = saved['model']
        f.close()
        with torch.no_grad():
            for key in ge_embeddings_dict:
                model.embedding.weight[int(key), :] = torch.FloatTensor(
                    ge_embeddings_dict[key])
    else:
        print("No embd prep.")

    from from_certified.attack_surface import WordSubstitutionAttackSurface, LMConstrainedAttackSurface
    if opt.lm_constraint:
        attack_surface = LMConstrainedAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)
    else:
        attack_surface = WordSubstitutionAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)

    best_adv_acc = 0
    for epoch in range(21):

        # Optionally ramp the adversarial weight from 0 to 1 over the first
        # 10 epochs, otherwise use the fixed weights from opt.
        if opt.smooth_ce:
            if epoch < 10:
                weight_adv = epoch * 1.0 / 10
                weight_clean = 1 - weight_adv
            else:
                weight_adv = 1
                weight_clean = 0
        else:
            weight_adv = opt.weight_adv
            weight_clean = opt.weight_clean

        # KL term is switched on starting from opt.kl_start_epoch.
        if epoch >= opt.kl_start_epoch:
            kl_control = 1

        sum_loss = sum_loss_adv = sum_loss_kl = sum_loss_clean = 0
        total = 0

        for iters, batch in enumerate(train_iter):

            text = batch[0].to(device)
            label = batch[1].to(device)
            anch = batch[2].to(device)
            pos = batch[3].to(device)
            neg = batch[4].to(device)
            anch_valid = batch[5].to(device).unsqueeze(2)
            text_like_syn = batch[6].to(device)
            text_like_syn_valid = batch[7].to(device)

            bs, sent_len = text.shape

            model.train()

            # zero grad
            optimizer.zero_grad()

            # Generate the perturbation for this batch.
            if opt.pert_set == "ad_text":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.attack_sparse_weight,
                    'out_type': "text"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                text_adv = model(mode="get_adv_by_convex_syn",
                                 input=embd,
                                 label=label,
                                 text_like_syn_embd=text_like_syn_embd,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)

            elif opt.pert_set == "ad_text_syn_p":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.train_attack_sparse_weight,
                    'out_type': "comb_p"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                adv_comb_p = model(mode="get_adv_by_convex_syn",
                                   input=embd,
                                   label=label,
                                   text_like_syn_embd=text_like_syn_embd,
                                   text_like_syn_valid=text_like_syn_valid,
                                   attack_type_dict=attack_type_dict)

            elif opt.pert_set == "ad_text_hotflip":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                }
                text_adv = model(mode="get_adv_hotflip",
                                 input=text,
                                 label=label,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)

            elif opt.pert_set == "l2_ball":
                set_radius = opt.train_attack_eps
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'step_size': opt.train_attack_step_size * set_radius,
                    'random_start': opt.random_start,
                    'epsilon': set_radius,
                    #'loss_func': 'ce',
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'direction': 'away',
                    'ball_range': opt.l2_ball_range,
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                embd_adv = model(mode="get_embd_adv",
                                 input=embd,
                                 label=label,
                                 attack_type_dict=attack_type_dict)

            optimizer.zero_grad()
            # clean loss
            predicted = model(mode="text_to_logit", input=text)
            loss_clean = loss_fun(predicted, label)
            # adv loss
            if opt.pert_set == "ad_text" or opt.pert_set == "ad_text_hotflip":
                predicted_adv = model(mode="text_to_logit", input=text_adv)
            elif opt.pert_set == "ad_text_syn_p":
                predicted_adv = model(mode="text_syn_p_to_logit",
                                      input=text_like_syn,
                                      comb_p=adv_comb_p)
            elif opt.pert_set == "l2_ball":
                predicted_adv = model(mode="embd_to_logit", input=embd_adv)

            loss_adv = loss_fun(predicted_adv, label)
            # kl loss between adversarial and clean predictions
            criterion_kl = nn.KLDivLoss(reduction="sum")
            loss_kl = (1.0 / bs) * criterion_kl(
                F.log_softmax(predicted_adv, dim=1), F.softmax(predicted,
                                                               dim=1))

            # optimize
            loss = opt.weight_kl * kl_control * loss_kl + weight_adv * loss_adv + weight_clean * loss_clean
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            sum_loss_adv += loss_adv.item()
            sum_loss_clean += loss_clean.item()
            sum_loss_kl += loss_kl.item()
            predicted, idx = torch.max(predicted, 1)
            precision = (idx == label).float().mean().item()
            predicted_adv, idx = torch.max(predicted_adv, 1)
            precision_adv = (idx == label).float().mean().item()
            total += 1

            out_log = "%d epoch %d iters: loss: %.3f, loss_kl: %.3f, loss_adv: %.3f, loss_clean: %.3f | acc: %.3f acc_adv: %.3f | in %.3f seconds" % (
                epoch, iters, sum_loss / total, sum_loss_kl / total,
                sum_loss_adv / total, sum_loss_clean / total, precision,
                precision_adv, time.time() - start)
            start = time.time()
            print(out_log)

        scheduler.step()

        # Per-epoch dev evaluation: clean accuracy and ASCC-attack accuracy.
        if epoch % 1 == 0:
            acc = utils.imdb_evaluation(opt, device, model, dev_iter)
            out_log = "%d epoch with dev acc %.4f" % (epoch, acc)
            print(out_log)
            adv_acc = utils.imdb_evaluation_ascc_attack(
                opt, device, model, dev_iter, tokenizer)
            out_log = "%d epoch with dev adv acc against ascc attack %.4f" % (
                epoch, adv_acc)
            print(out_log)

            #hotflip_adv_acc=utils.evaluation_hotflip_adv(opt, device, model, dev_iter, tokenizer)
            #out_log="%d epoch with dev hotflip adv acc %.4f" % (epoch,hotflip_adv_acc)
            #logger.info(out_log)
            #print(out_log)

            if adv_acc >= best_adv_acc:
                best_adv_acc = adv_acc
                best_save_dir = os.path.join(opt.out_path,
                                             "{}_best.pth".format(opt.model))
                state = {
                    'net': model.state_dict(),
                    'epoch': epoch,
                }
                torch.save(state, best_save_dir)

    # restore best according to dev set
    model = set_params(model, best_save_dir)
    acc = utils.imdb_evaluation(opt, device, model, test_iter)
    print("test acc %.4f" % (acc))
    adv_acc = utils.imdb_evaluation_ascc_attack(opt, device, model, test_iter,
                                                tokenizer)
    print("test adv acc against ascc attack %.4f" % (adv_acc))
    genetic_attack(opt,
                   device,
                   model,
                   attack_surface,
                   dataset=opt.dataset,
                   genetic_test_num=opt.genetic_test_num)
    fool_text_classifier_pytorch(opt,
                                 device,
                                 model,
                                 dataset=opt.dataset,
                                 clean_samples_cap=opt.pwws_test_num)
# Example #7
# 0
    def get_adv_by_convex_syn(self, embd_p, embd_h, y, x_p_text_like_syn,
                              x_p_syn_embd, x_p_syn_valid, x_h_text_like_syn,
                              x_h_syn_embd, x_h_syn_valid, x_p_mask, x_h_mask,
                              attack_type_dict):
        """Craft adversarial premise/hypothesis inputs via convex synonym combinations.

        For each of premise (``_p``) and hypothesis (``_h``) a weight tensor
        over synonym candidates is optimized by gradient ascent on the model
        loss; ``softmax`` along the synonym axis gives combination
        coefficients.  When ``attack_type_dict['attack_hypo_only']`` is True
        (the default) only the hypothesis is perturbed, and the returned
        premise combination probabilities are meaningless.

        Returns a pair (premise, hypothesis): discretized adversarial token
        ids when out_type == "text", or the detached combination probabilities
        when out_type == "comb_p".
        """

        #noted that if attack hypo only then the output x_p_comb_p is meaningless

        # record context
        self_training_context = self.training
        # set context
        if self.eval_adv_mode:
            self.eval()
        else:
            self.train()

        device = embd_p.device
        # get param of attacks

        num_steps = attack_type_dict['num_steps']
        loss_func = attack_type_dict['loss_func']
        w_optm_lr = attack_type_dict['w_optm_lr']
        sparse_weight = attack_type_dict['sparse_weight']
        out_type = attack_type_dict['out_type']
        attack_hypo_only = attack_type_dict[
            'attack_hypo_only'] if 'attack_hypo_only' in attack_type_dict else True

        batch_size, text_len, embd_dim = embd_p.shape
        batch_size, text_len, syn_num, embd_dim = x_p_syn_embd.shape

        w_p = torch.empty(batch_size, text_len, syn_num,
                          1).to(device).to(embd_p.dtype)
        w_h = torch.empty(batch_size, text_len, syn_num,
                          1).to(device).to(embd_p.dtype)
        #ww = torch.zeros(batch_size, text_len, syn_num, 1).to(device).to(embd.dtype)
        #ww = ww+500*(syn_valid.reshape(batch_size, text_len, syn_num, 1)-1)
        nn.init.kaiming_normal_(w_p)
        nn.init.kaiming_normal_(w_h)
        w_p.requires_grad_()
        w_h.requires_grad_()

        import utils
        params = [w_p, w_h]
        optimizer = utils.getOptimizer(params,
                                       name='adam',
                                       lr=w_optm_lr,
                                       weight_decay=2e-5)

        def get_comb_p(w, syn_valid):
            # Mask invalid slots with -10000 before softmax (~0 probability).
            ww = w * syn_valid.reshape(
                batch_size, text_len, syn_num, 1) + 10000 * (
                    syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
            return F.softmax(ww, -2)

        def get_comb_ww(w, syn_valid):
            # Same masking as get_comb_p but returns the raw logits.
            ww = w * syn_valid.reshape(
                batch_size, text_len, syn_num, 1) + 10000 * (
                    syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
            return ww

        def get_comb(p, syn):
            # Convex combination of synonym embeddings (sum over syn axis).
            return (p * syn.detach()).sum(-2)

        embd_p_ori = embd_p.detach()
        embd_h_ori = embd_h.detach()
        logit_ori = self.embd_to_logit(embd_p_ori, embd_h_ori, x_p_mask,
                                       x_h_mask)

        # Maximize the attack loss (note the negated losses below) over w_p/w_h.
        for _ in range(num_steps):
            optimizer.zero_grad()
            with torch.enable_grad():
                ww_p = get_comb_ww(w_p, x_p_syn_valid)
                ww_h = get_comb_ww(w_h, x_h_syn_valid)
                #comb_p = get_comb_p(w, syn_valid)
                embd_p_adv = get_comb(F.softmax(ww_p, -2), x_p_syn_embd)
                embd_h_adv = get_comb(F.softmax(ww_h, -2), x_h_syn_embd)
                if attack_hypo_only:
                    logit_adv = self.embd_to_logit(embd_p_ori, embd_h_adv,
                                                   x_p_mask, x_h_mask)
                else:
                    logit_adv = self.embd_to_logit(embd_p_adv, embd_h_adv,
                                                   x_p_mask, x_h_mask)

                if loss_func == 'ce':
                    loss = -F.cross_entropy(logit_adv, y, reduction='sum')
                elif loss_func == 'kl':
                    criterion_kl = nn.KLDivLoss(reduction="sum")
                    loss = -criterion_kl(F.log_softmax(logit_adv, dim=1),
                                         F.softmax(logit_ori.detach(), dim=1))

                #print("ad loss:", loss.data.item())

                # Entropy regularizer pushes each combination toward a single
                # synonym (sparser comb_p).
                if sparse_weight != 0:
                    #loss_sparse = (comb_p*comb_p).mean()
                    if attack_hypo_only:
                        loss_sparse = (-F.softmax(ww_h, -2) *
                                       F.log_softmax(ww_h, -2)).sum(-2).mean()
                    else:
                        loss_sparse = (
                            (-F.softmax(ww_p, -2) *
                             F.log_softmax(ww_p, -2)).sum(-2).mean() +
                            (-F.softmax(ww_h, -2) *
                             F.log_softmax(ww_h, -2)).sum(-2).mean()) / 2
                    #loss -= sparse_weight*loss_sparse

                    loss = loss + sparse_weight * loss_sparse
                    #print(loss_sparse.data.item())

            #loss*=1000
            loss.backward()
            optimizer.step()

        #print((ww-w).max())

        x_p_comb_p = get_comb_p(w_p, x_p_syn_valid)
        x_h_comb_p = get_comb_p(w_h, x_h_syn_valid)
        """
        out = get_comb(comb_p, syn)
        delta = (out-embd_ori).reshape(batch_size*text_len,embd_dim)
        delta = F.pairwise_distance(delta, torch.zeros_like(delta), p=2.0)
        valid = (delta>0.01).to(device).to(delta.dtype)
        delta = (valid*delta).sum()/valid.sum()
        print("mean l2 dis between embd and embd_adv:", delta.data.item())
        #print("mean max comb_p:", (comb_p.max(-2)[0]).mean().data.item())
        """

        # resume context
        if self_training_context:
            self.train()
        else:
            self.eval()

        if out_type == "comb_p":
            return x_p_comb_p.detach(), x_h_comb_p.detach()
        elif out_type == "text":
            assert (x_p_text_like_syn is not None)  # n l synlen
            assert (x_h_text_like_syn is not None)  # n l synlen
            # Discretize: per token pick the synonym with max probability.
            x_p_comb_p = x_p_comb_p.reshape(batch_size * text_len, syn_num)
            x_h_comb_p = x_h_comb_p.reshape(batch_size * text_len, syn_num)
            ind_x_p = x_p_comb_p.max(-1)[1]  # shape batch_size* text_len
            ind_x_h = x_h_comb_p.max(-1)[1]  # shape batch_size* text_len
            adv_text_x_p = (x_p_text_like_syn.reshape(
                batch_size * text_len,
                syn_num)[np.arange(batch_size * text_len),
                         ind_x_p]).reshape(batch_size, text_len)
            adv_text_x_h = (x_h_text_like_syn.reshape(
                batch_size * text_len,
                syn_num)[np.arange(batch_size * text_len),
                         ind_x_h]).reshape(batch_size, text_len)
            return adv_text_x_p, adv_text_x_h
        # Previously an unknown out_type silently returned None.
        raise ValueError("unsupported out_type: %s" % out_type)