def main():
    """Parse CLI arguments, build the model and data readers, load the
    pretrained embedding if configured, and launch training/evaluation.

    Side effects: reads `CUDA_VISIBLE_DEVICES`, prints progress, and calls
    `train_and_evaluate` (defined elsewhere in this file/module).
    """
    args = parser.parse_args()
    # Config classes are looked up by name in the `configs` module.
    global_config = configs.__dict__[args.config]()

    if args.epoch_num is not None:  # PEP 8: compare to None with `is not`
        global_config.epoch_num = args.epoch_num

    print("net_name: ", args.model_name)
    # Model constructors are likewise looked up by name.
    net = models.__dict__[args.model_name](global_config)

    # get word_dict
    word_dict = utils.getDict(data_type="quora_question_pairs")

    # get reader
    train_reader, dev_reader, test_reader = utils.prepare_data(
        "quora_question_pairs",
        word_dict=word_dict,
        batch_size=global_config.batch_size,
        buf_size=800000,
        duplicate_data=global_config.duplicate_data,
        use_pad=(not global_config.use_lod_tensor))

    # load pretrained_word_embedding
    if global_config.use_pretrained_word_embedding:
        word2vec = Glove840B_300D(
            filepath=os.path.join(DATA_DIR, "glove.840B.300d.txt"),
            keys=set(word_dict.keys()))
        pretrained_word_embedding = utils.get_pretrained_word_embedding(
            word2vec=word2vec, word2id=word_dict, config=global_config)
        print("pretrained_word_embedding to be load:",
              pretrained_word_embedding)
    else:
        pretrained_word_embedding = None

    # define optimizer
    optimizer = utils.getOptimizer(global_config)

    # Default `use_cuda` from the environment when the config does not pin
    # it: a present, non-empty CUDA_VISIBLE_DEVICES enables CUDA.
    # (Equivalent to the original membership + inequality checks.)
    if not global_config.has_member('use_cuda'):
        global_config.use_cuda = bool(os.environ.get('CUDA_VISIBLE_DEVICES', ''))

    global_config.list_config()

    train_and_evaluate(
        train_reader,
        dev_reader,
        test_reader,
        net,
        optimizer,
        global_config,
        pretrained_word_embedding,
        use_cuda=global_config.use_cuda,
        parallel=False)
def get_adv_by_convex_syn(self, embd, y, syn, syn_valid, text_like_syn,
                          attack_type_dict, text_for_vis, record_for_vis):
    """Craft an adversarial example as a convex combination of synonym
    embeddings (ASCC-style attack).

    Optimizes per-position combination weights `w` over `num_steps` Adam
    steps to maximize the loss (CE or KL vs. the clean logits), with an
    optional entropy penalty (`sparse_weight`) pushing the combination
    toward a single synonym.

    Args:
        embd: clean token embeddings; indexing below assumes shape
            (batch, text_len, embd_dim).
        y: gold labels for the CE loss.
        syn: synonym embeddings, (batch, text_len, syn_num, embd_dim).
        syn_valid: 0/1 validity mask for synonym slots; invalid slots get a
            large negative logit (-500) so softmax assigns them ~0 mass.
        text_like_syn: synonym token ids, used when out_type == "text".
        attack_type_dict: keys 'num_steps', 'loss_func' ('ce'|'kl'),
            'w_optm_lr', 'sparse_weight', 'out_type' ('text'|'comb_p').
        text_for_vis / record_for_vis: optional visualization hooks; only
            used when self.opt.vis_w_key_token is set.

    Returns:
        Detached adversarial token ids (out_type "text") or combination
        probabilities (out_type "comb_p").

    Side effects: may dump a pickle and call sys.exit() in visualization
    mode; temporarily switches train/eval mode and restores it at the end.
    """
    # record context
    self_training_context = self.training
    # set context
    if self.eval_adv_mode:
        self.eval()
    else:
        self.train()
    device = embd.device

    # get param of attacks
    num_steps = attack_type_dict['num_steps']
    loss_func = attack_type_dict['loss_func']
    w_optm_lr = attack_type_dict['w_optm_lr']
    sparse_weight = attack_type_dict['sparse_weight']
    out_type = attack_type_dict['out_type']

    batch_size, text_len, embd_dim = embd.shape
    batch_size, text_len, syn_num, embd_dim = syn.shape

    # One free logit per (position, synonym); softmax over the synonym
    # axis turns it into a convex-combination weight.
    w = torch.empty(batch_size, text_len, syn_num,
                    1).to(device).to(embd.dtype)
    #ww = torch.zeros(batch_size, text_len, syn_num, 1).to(device).to(embd.dtype)
    #ww = ww+500*(syn_valid.reshape(batch_size, text_len, syn_num, 1)-1)
    nn.init.kaiming_normal_(w)
    w.requires_grad_()

    import utils
    params = [w]
    optimizer = utils.getOptimizer(params,
                                   name='adam',
                                   lr=w_optm_lr,
                                   weight_decay=2e-5)

    def get_comb_p(w, syn_valid):
        # Mask invalid synonym slots with -500 before softmax.
        ww = w * syn_valid.reshape(
            batch_size, text_len, syn_num, 1) + 500 * (
                syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
        return F.softmax(ww, -2)

    def get_comb_ww(w, syn_valid):
        # Same masking, but returns raw logits (softmax applied by caller).
        ww = w * syn_valid.reshape(
            batch_size, text_len, syn_num, 1) + 500 * (
                syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
        return ww

    def get_comb(p, syn):
        # Convex combination of synonym embeddings (synonyms are frozen).
        return (p * syn.detach()).sum(-2)

    embd_ori = embd.detach()
    logit_ori = self.embd_to_logit(embd_ori)

    # Ascent loop: losses are negated so the Adam *minimizer* maximizes them.
    for _ in range(num_steps):
        optimizer.zero_grad()
        with torch.enable_grad():
            ww = get_comb_ww(w, syn_valid)
            #comb_p = get_comb_p(w, syn_valid)
            embd_adv = get_comb(F.softmax(ww, -2), syn)
            if loss_func == 'ce':
                logit_adv = self.embd_to_logit(embd_adv)
                loss = -F.cross_entropy(logit_adv, y, reduction='sum')
            elif loss_func == 'kl':
                logit_adv = self.embd_to_logit(embd_adv)
                criterion_kl = nn.KLDivLoss(reduction="sum")
                loss = -criterion_kl(F.log_softmax(logit_adv, dim=1),
                                     F.softmax(logit_ori.detach(), dim=1))
            #print("ad loss:", loss.data.item())

            if sparse_weight != 0:
                #loss_sparse = (comb_p*comb_p).mean()
                # Entropy of the combination distribution; adding it (with
                # the negated loss) encourages peaked, near-one-hot weights.
                loss_sparse = (-F.softmax(ww, -2) *
                               F.log_softmax(ww, -2)).sum(-2).mean()
                #loss -= sparse_weight*loss_sparse
                loss = loss + sparse_weight * loss_sparse
                #print(loss_sparse.data.item())

        #loss*=1000
        loss.backward()
        optimizer.step()
        #print((ww-w).max())

    comb_p = get_comb_p(w, syn_valid)

    # Optional visualization: collect combination weights at positions whose
    # token matches vis_w_key_token; dump and exit after 300 records.
    if self.opt.vis_w_key_token is not None:
        assert (text_for_vis is not None and record_for_vis is not None)
        vis_n, vis_l = text_for_vis.shape
        for i in range(vis_n):
            for j in range(vis_l):
                if text_for_vis[i, j] == self.opt.vis_w_key_token:
                    record_for_vis["comb_p_list"].append(
                        comb_p[i, j].cpu().detach().numpy())
                    record_for_vis["embd_syn_list"].append(
                        syn[i, j].cpu().detach().numpy())
                    record_for_vis["syn_valid_list"].append(
                        syn_valid[i, j].cpu().detach().numpy())
                    record_for_vis["text_syn_list"].append(
                        text_like_syn[i, j].cpu().detach().numpy())
        print("record for vis", len(record_for_vis["comb_p_list"]))
        if len(record_for_vis["comb_p_list"]) >= 300:
            dir_name = self.opt.resume.split(self.opt.model)[0]
            file_name = self.opt.dataset + "_vis_w_" + str(
                self.opt.attack_sparse_weight) + "_" + str(
                    self.opt.vis_w_key_token) + ".pkl"
            file_name = os.path.join(dir_name, file_name)
            f = open(file_name, 'wb')
            pickle.dump(record_for_vis, f)
            f.close()
            # Hard-exits the process once enough records are gathered.
            sys.exit()

    if out_type == "text":
        # NOTE(original): needs fixing — has potential bugs; the trigger
        # depends on the data.
        assert (text_like_syn is not None)  # n l synlen
        # Pick, per position, the synonym id with the highest weight.
        comb_p = comb_p.reshape(batch_size * text_len, syn_num)
        ind = comb_p.max(-1)[1]  # shape batch_size* text_len
        out = (text_like_syn.reshape(
            batch_size * text_len,
            syn_num)[np.arange(batch_size * text_len),
                     ind]).reshape(batch_size, text_len)
    elif out_type == "comb_p":
        out = comb_p

    # resume context
    if self_training_context == True:
        self.train()
    else:
        self.eval()

    return out.detach()
import gym
import world
import utils
from Buffer import ReplayBuffer
from models import DQN
from world import Print, ARGS
from wrapper import WrapIt
from procedure import train_DQN

# --- environment -------------------------------------------------
env = gym.make('RiverraidNoFrameskip-v4')
env = WrapIt(env)
Print('ENV action', env.unwrapped.get_action_meanings())
Print('ENV observation',
      f"Image: {ARGS.imgDIM} X {ARGS.imgDIM} X {1}"
      )  # we assert to use gray image

# --- training components -----------------------------------------
optimizer = utils.getOptimizer()
exploration_schedule = utils.LinearSchedule(1000000, 0.1)
replay_buffer = ReplayBuffer(ARGS.buffersize, ARGS.framelen)

# Online network is optimized; the target network stays in eval mode
# and is only refreshed from the online one during training.
q_net = utils.init_model(env, DQN).train().to(world.DEVICE)
target_net = utils.init_model(env, DQN).eval().to(world.DEVICE)

# --- run ---------------------------------------------------------
train_DQN(env,
          Q=q_net,
          Q_target=target_net,
          optimizer=optimizer,
          replay_buffer=replay_buffer,
          exploration=exploration_schedule)
def train(opt, train_iter, test_iter, verbose=True):
    """Train the model described by `opt`, evaluate after every epoch, keep
    only the best checkpoint on disk, and record the best precision in the
    shared performance CSV (`performance_log_file`, defined at module level).

    Args:
        opt: experiment options namespace (model, optimizer, lr, max_epoch,
            grad_clip, dataset, from_torchtext, ...).
        train_iter: iterable of training batches with .text / .label.
        test_iter: evaluation iterator passed to utils.evaluation.
        verbose: when True, log per-batch loss and per-epoch precision.
    """
    global_start = time.time()
    logger = utils.getLogger()
    model = models.setup(opt)
    if torch.cuda.is_available():
        model.cuda()

    # Only parameters that require grad are counted and optimized.
    params = [param for param in model.parameters() if param.requires_grad]
    model_info = ";".join([
        str(k) + ":" + str(v) for k, v in opt.__dict__.items()
        if type(v) in (str, int, float, list, bool)
    ])
    logger.info("# parameters:" + str(sum(param.numel() for param in params)))
    logger.info(model_info)

    model.train()
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))
    loss_fun = F.cross_entropy

    filename = None
    percisions = []  # per-epoch test precision ("percision" kept for history)
    for i in range(opt.max_epoch):
        for epoch, batch in enumerate(train_iter):
            optimizer.zero_grad()
            start = time.time()
            # torchtext batches wrap the tensor in a (data, lengths) tuple.
            text = batch.text[0] if opt.from_torchtext else batch.text
            predicted = model(text)
            loss = loss_fun(predicted, batch.label)
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if verbose:
                # BUGFIX: the old CPU branch used `loss.data.numpy()[0]`,
                # which raises on modern PyTorch (indexing a 0-dim array).
                # `loss.item()` is correct for both CPU and CUDA tensors,
                # so the duplicated if/else collapses into one call.
                logger.info(
                    "%d iteration %d epoch with loss : %.5f in %.4f seconds"
                    % (i, epoch, loss.item(), time.time() - start))

        percision = utils.evaluation(model, test_iter, opt.from_torchtext)
        if verbose:
            logger.info("%d iteration with percision %.4f" % (i, percision))
        # Keep exactly one checkpoint: the best-so-far by precision.
        if len(percisions) == 0 or percision > max(percisions):
            if filename:
                os.remove(filename)
            filename = model.save(metric=percision)
        percisions.append(percision)

    # while(utils.is_writeable(performance_log_file)):
    df = pd.read_csv(performance_log_file, index_col=0, sep="\t")
    df.loc[model_info, opt.dataset] = max(percisions)
    df.to_csv(performance_log_file, sep="\t")
    logger.info(model_info + " with time :" + str(time.time() - global_start)
                + " ->" + str(max(percisions)))
    print(model_info + " with time :" + str(time.time() - global_start) +
          " ->" + str(max(percisions)))
def main():
    """Load a TOML config, prepare train/test sequence data, build and
    compile a Keras model, then fit it with checkpoint/early-stop/LR
    callbacks.

    Exits with status 1 if any required config key is missing.
    """
    # loading config file ... (path from argv[1], else ./config.toml)
    cfgPath = sys.argv[1] if len(sys.argv) > 1 else './config.toml'
    cfg = loadConfig(cfgPath)
    try:
        # ... and unpacking variables; any missing key raises KeyError,
        # caught below as a fatal configuration error.
        dictget = lambda d, *k: [d[i] for i in k]
        dataStats = cfg['data_stats']
        modelParams = cfg['model_params']
        trainCSV, testCSV = dictget(cfg['database'], 'train', 'test')
        seqLength, stepSize = dictget(cfg['model_params'], 'seqLength',
                                      'stepSize')
        modelArch, modelDir, modelName = dictget(cfg['model_arch'],
                                                 'modelArch', 'modelDir',
                                                 'modelName')
        optimizer, lossFunc, metricFuncs = dictget(cfg['training_params'],
                                                   'optimizer', 'lossFunc',
                                                   'metricFuncs')
        lr, epochs, batchSize, patience, = dictget(cfg['training_params'],
                                                   'learningRate', 'epochs',
                                                   'batchSize', 'patience')
    except KeyError as err:
        print("\n\nERROR: not all parameters defined in config.toml : ", err)
        print("Exiting ... \n\n")
        sys.exit(1)

    print("Loading training data ...")
    xTrain, yTrain, stats = getData(trainCSV,
                                    seqLength=seqLength,
                                    stepSize=stepSize,
                                    stats=dataStats)
    print("Training Data Shape : ", xTrain.shape, "\n")
    print("Loading testing data ...")
    xTest, yTest, stats = getData(testCSV,
                                  seqLength=seqLength,
                                  stepSize=stepSize,
                                  stats=dataStats)
    print("Testing Data Shape : ", xTest.shape, "\n")

    yTrain = np.expand_dims(
        yTrain, -1)  # adding extra axis as model expects 2 axis in the output
    yTest = np.expand_dims(yTest, -1)

    print("Compiling Model")
    opt = getOptimizer(optimizer, lr)
    model = makeModel(modelArch, modelParams, verbose=True)
    model.compile(loss=lossFunc, optimizer=opt, metrics=metricFuncs)

    # setting up directories (weights live under <modelDir>/<modelName>/weights)
    modelFolder = os.path.join(modelDir, modelName)
    weightsFolder = os.path.join(modelFolder, "weights")
    bestModelPath = os.path.join(weightsFolder, "best.hdf5")
    ensureDir(bestModelPath)
    saveConfig(cfgPath, modelFolder)  # snapshot the config next to the model

    # callbacks
    monitorMetric = 'val_loss'
    # check1: every-epoch checkpoint (only enabled via saveAllWeights below)
    check1 = ModelCheckpoint(os.path.join(weightsFolder,
                                          modelName + "_{epoch:03d}.hdf5"),
                             monitor=monitorMetric,
                             mode='auto')
    # check2: best-only checkpoint; check3: early stop; check4: CSV log;
    # check5: reduce LR on plateau (patience is a third of early-stop's).
    check2 = ModelCheckpoint(bestModelPath,
                             monitor=monitorMetric,
                             save_best_only=True,
                             mode='auto')
    check3 = EarlyStopping(monitor=monitorMetric,
                           min_delta=0.01,
                           patience=patience,
                           verbose=0,
                           mode='auto')
    check4 = CSVLogger(os.path.join(modelFolder,
                                    modelName + '_trainingLog.csv'),
                       separator=',',
                       append=True)
    check5 = ReduceLROnPlateau(monitor=monitorMetric,
                               factor=0.1,
                               patience=patience // 3,
                               verbose=1,
                               mode='auto',
                               min_delta=0.001,
                               cooldown=0,
                               min_lr=1e-10)
    cb = [check2, check3, check4, check5]
    if cfg['training_params']['saveAllWeights']:
        cb.append(check1)

    print("Starting Training ...")
    model.fit(x=xTrain,
              y=yTrain,
              batch_size=batchSize,
              epochs=epochs,
              verbose=1,
              callbacks=cb,
              validation_data=(xTest, yTest),
              shuffle=True)
def train(opt, train_iter, dev_iter, test_iter, syn_data, verbose=True):
    """Adversarially train a text classifier against synonym-substitution
    attacks (ASCC-style), tracking the best dev adversarial accuracy.

    Pipeline: optional embedding preparation (tie synonym embeddings via
    union-find, or load graph embeddings), then for each epoch generate
    adversarial inputs per `opt.pert_set`, optimize a weighted sum of
    clean/adv/KL losses, evaluate on dev, checkpoint the best model, and
    finally evaluate the restored best model on test plus run genetic and
    PWWS attacks.

    Args:
        opt: experiment options namespace (attack, loss-weight, path fields).
        train_iter: batches laid out as (text, label, anch, pos, neg,
            anch_valid, text_like_syn, text_like_syn_valid) — inferred from
            the indexing below.
        dev_iter / test_iter: evaluation iterators.
        syn_data: synonym lists indexed by token id (used by "same" prep).
        verbose: unused here; kept for signature compatibility.
    """
    global_start = time.time()
    #logger = utils.getLogger()
    model = models.setup(opt)
    if opt.resume != None:
        model = set_params(model, opt.resume)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        model.cuda()
        #model=torch.nn.DataParallel(model)

    # set optimizer (optionally freezing the embedding layer)
    if opt.embd_freeze == True:
        model.embedding.weight.requires_grad = False
    else:
        model.embedding.weight.requires_grad = True
    params = [param for param in model.parameters() if param.requires_grad
              ]  #filter(lambda p: p.requires_grad, model.parameters())
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   weight_decay=opt.weight_decay,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))
    # LR warmup + decay at epochs 40/80 (decay factor 0.1).
    scheduler = WarmupMultiStepLR(optimizer, (40, 80), 0.1, 1.0 / 10.0, 2,
                                  'linear')

    from label_smooth import LabelSmoothSoftmaxCE
    if opt.label_smooth != 0:
        assert (opt.label_smooth <= 1 and opt.label_smooth > 0)
        loss_fun = LabelSmoothSoftmaxCE(lb_pos=1 - opt.label_smooth,
                                        lb_neg=opt.label_smooth)
    else:
        loss_fun = F.cross_entropy

    filename = None
    acc_adv_list = []
    start = time.time()
    kl_control = 0  # KL term disabled until opt.kl_start_epoch

    # initialize synonyms with the same embd
    from PWWS.word_level_process import word_process, get_tokenizer
    tokenizer = get_tokenizer(opt)

    if opt.embedding_prep == "same":
        # Union-find over token ids: every synonym group collapses to one
        # representative whose embedding is copied to all members.
        father_dict = {}
        for index in range(1 + len(tokenizer.index_word)):
            father_dict[index] = index

        def get_father(x):
            # Find with path compression.
            if father_dict[x] == x:
                return x
            else:
                fa = get_father(father_dict[x])
                father_dict[x] = fa
                return fa

        for index in range(len(syn_data) - 1, 0, -1):
            syn_list = syn_data[index]
            for pos in syn_list:
                fa_pos = get_father(pos)
                fa_anch = get_father(index)
                if fa_pos == fa_anch:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                else:
                    # Union: current index becomes the new representative.
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                    father_dict[fa_pos] = index

        print("Same embedding for synonyms as embd prep.")
        set_different_embd = set()
        for key in father_dict:
            fa = get_father(key)
            set_different_embd.add(fa)
            with torch.no_grad():
                model.embedding.weight[key, :] = model.embedding.weight[fa, :]
        print(len(set_different_embd))

    elif opt.embedding_prep == "ge":
        # Load precomputed graph ("walk") embeddings from a pickle and copy
        # them into the embedding table, keyed by token id.
        print("Graph embedding as embd prep.")
        ge_file_path = opt.ge_file_path
        f = open(ge_file_path, 'rb')
        saved = pickle.load(f)
        ge_embeddings_dict = saved['walk_embeddings']
        #model = saved['model']
        f.close()
        with torch.no_grad():
            for key in ge_embeddings_dict:
                model.embedding.weight[int(key), :] = torch.FloatTensor(
                    ge_embeddings_dict[key])
    else:
        print("No embd prep.")

    from from_certified.attack_surface import WordSubstitutionAttackSurface, LMConstrainedAttackSurface
    if opt.lm_constraint:
        attack_surface = LMConstrainedAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)
    else:
        attack_surface = WordSubstitutionAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)

    best_adv_acc = 0
    # NOTE(review): epoch count hard-coded to 21 — presumably intentional
    # for this experiment; confirm before reuse.
    for epoch in range(21):
        # Smooth CE schedule: ramp adv weight 0->1 over first 10 epochs,
        # clean weight is its complement; otherwise use fixed opt weights.
        if opt.smooth_ce:
            if epoch < 10:
                weight_adv = epoch * 1.0 / 10
                weight_clean = 1 - weight_adv
            else:
                weight_adv = 1
                weight_clean = 0
        else:
            weight_adv = opt.weight_adv
            weight_clean = opt.weight_clean

        if epoch >= opt.kl_start_epoch:
            kl_control = 1

        sum_loss = sum_loss_adv = sum_loss_kl = sum_loss_clean = 0
        total = 0
        for iters, batch in enumerate(train_iter):
            # Batch layout inferred from indices; anch/pos/neg appear unused
            # in this loop body.
            text = batch[0].to(device)
            label = batch[1].to(device)
            anch = batch[2].to(device)
            pos = batch[3].to(device)
            neg = batch[4].to(device)
            anch_valid = batch[5].to(device).unsqueeze(2)
            text_like_syn = batch[6].to(device)
            text_like_syn_valid = batch[7].to(device)
            bs, sent_len = text.shape

            model.train()
            # zero grad
            optimizer.zero_grad()

            # ---- generate the adversarial input per opt.pert_set ----
            if opt.pert_set == "ad_text":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.attack_sparse_weight,
                    'out_type': "text"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                text_adv = model(mode="get_adv_by_convex_syn",
                                 input=embd,
                                 label=label,
                                 text_like_syn_embd=text_like_syn_embd,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)
            elif opt.pert_set == "ad_text_syn_p":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.train_attack_sparse_weight,
                    'out_type': "comb_p"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                adv_comb_p = model(mode="get_adv_by_convex_syn",
                                   input=embd,
                                   label=label,
                                   text_like_syn_embd=text_like_syn_embd,
                                   text_like_syn_valid=text_like_syn_valid,
                                   attack_type_dict=attack_type_dict)
            elif opt.pert_set == "ad_text_hotflip":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                }
                text_adv = model(mode="get_adv_hotflip",
                                 input=text,
                                 label=label,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)
            elif opt.pert_set == "l2_ball":
                set_radius = opt.train_attack_eps
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'step_size': opt.train_attack_step_size * set_radius,
                    'random_start': opt.random_start,
                    'epsilon': set_radius,
                    #'loss_func': 'ce',
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'direction': 'away',
                    'ball_range': opt.l2_ball_range,
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                embd_adv = model(mode="get_embd_adv",
                                 input=embd,
                                 label=label,
                                 attack_type_dict=attack_type_dict)

            optimizer.zero_grad()
            # clean loss
            predicted = model(mode="text_to_logit", input=text)
            loss_clean = loss_fun(predicted, label)
            # adv loss (route through the mode matching the attack output)
            if opt.pert_set == "ad_text" or opt.pert_set == "ad_text_hotflip":
                predicted_adv = model(mode="text_to_logit", input=text_adv)
            elif opt.pert_set == "ad_text_syn_p":
                predicted_adv = model(mode="text_syn_p_to_logit",
                                      input=text_like_syn,
                                      comb_p=adv_comb_p)
            elif opt.pert_set == "l2_ball":
                predicted_adv = model(mode="embd_to_logit", input=embd_adv)
            loss_adv = loss_fun(predicted_adv, label)
            # kl loss (batch-mean via 1/bs of the summed KL)
            criterion_kl = nn.KLDivLoss(reduction="sum")
            loss_kl = (1.0 / bs) * criterion_kl(
                F.log_softmax(predicted_adv, dim=1),
                F.softmax(predicted, dim=1))
            # optimize
            loss = opt.weight_kl * kl_control * loss_kl + weight_adv * loss_adv + weight_clean * loss_clean
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            sum_loss_adv += loss_adv.item()
            sum_loss_clean += loss_clean.item()
            sum_loss_kl += loss_kl.item()
            predicted, idx = torch.max(predicted, 1)
            precision = (idx == label).float().mean().item()
            predicted_adv, idx = torch.max(predicted_adv, 1)
            precision_adv = (idx == label).float().mean().item()
            total += 1

            out_log = "%d epoch %d iters: loss: %.3f, loss_kl: %.3f, loss_adv: %.3f, loss_clean: %.3f | acc: %.3f acc_adv: %.3f | in %.3f seconds" % (
                epoch, iters, sum_loss / total, sum_loss_kl / total,
                sum_loss_adv / total, sum_loss_clean / total, precision,
                precision_adv, time.time() - start)
            start = time.time()
            print(out_log)

        scheduler.step()

        if epoch % 1 == 0:
            acc = utils.imdb_evaluation(opt, device, model, dev_iter)
            out_log = "%d epoch with dev acc %.4f" % (epoch, acc)
            print(out_log)
            adv_acc = utils.imdb_evaluation_ascc_attack(
                opt, device, model, dev_iter, tokenizer)
            out_log = "%d epoch with dev adv acc against ascc attack %.4f" % (
                epoch, adv_acc)
            print(out_log)
            #hotflip_adv_acc=utils.evaluation_hotflip_adv(opt, device, model, dev_iter, tokenizer)
            #out_log="%d epoch with dev hotflip adv acc %.4f" % (epoch,hotflip_adv_acc)
            #logger.info(out_log)
            #print(out_log)
            # Checkpoint whenever dev adversarial accuracy ties or improves.
            if adv_acc >= best_adv_acc:
                best_adv_acc = adv_acc
                best_save_dir = os.path.join(opt.out_path,
                                             "{}_best.pth".format(opt.model))
                state = {
                    'net': model.state_dict(),
                    'epoch': epoch,
                }
                torch.save(state, best_save_dir)

    # restore best according to dev set
    model = set_params(model, best_save_dir)
    acc = utils.imdb_evaluation(opt, device, model, test_iter)
    print("test acc %.4f" % (acc))
    adv_acc = utils.imdb_evaluation_ascc_attack(opt, device, model, test_iter,
                                                tokenizer)
    print("test adv acc against ascc attack %.4f" % (adv_acc))
    # Final robustness evaluations: genetic attack and PWWS.
    genetic_attack(opt,
                   device,
                   model,
                   attack_surface,
                   dataset=opt.dataset,
                   genetic_test_num=opt.genetic_test_num)
    fool_text_classifier_pytorch(opt,
                                 device,
                                 model,
                                 dataset=opt.dataset,
                                 clean_samples_cap=opt.pwws_test_num)
def get_adv_by_convex_syn(self, embd_p, embd_h, y, x_p_text_like_syn,
                          x_p_syn_embd, x_p_syn_valid, x_h_text_like_syn,
                          x_h_syn_embd, x_h_syn_valid, x_p_mask, x_h_mask,
                          attack_type_dict):
    """NLI variant of the convex-synonym attack: optimize combination
    weights for premise (p) and hypothesis (h) simultaneously.

    Same scheme as the single-text version — Adam-ascent on per-position
    synonym weights with masked softmax (invalid slots get -10000 logits)
    and an optional entropy sparsity term — but with two weight tensors.
    When `attack_hypo_only` (default True), only the hypothesis embedding
    is perturbed in the adversarial forward pass, and only its entropy
    enters the sparsity loss.

    Returns (premise, hypothesis) outputs: combination probabilities when
    out_type == "comb_p", or argmax token ids when out_type == "text".
    """
    #noted that if attack hypo only then the output x_p_comb_p is meaningless
    # record context
    self_training_context = self.training
    # set context
    if self.eval_adv_mode:
        self.eval()
    else:
        self.train()
    device = embd_p.device

    # get param of attacks
    num_steps = attack_type_dict['num_steps']
    loss_func = attack_type_dict['loss_func']
    w_optm_lr = attack_type_dict['w_optm_lr']
    sparse_weight = attack_type_dict['sparse_weight']
    out_type = attack_type_dict['out_type']
    attack_hypo_only = attack_type_dict[
        'attack_hypo_only'] if 'attack_hypo_only' in attack_type_dict else True

    batch_size, text_len, embd_dim = embd_p.shape
    batch_size, text_len, syn_num, embd_dim = x_p_syn_embd.shape

    # One free logit per (position, synonym), separately for p and h.
    w_p = torch.empty(batch_size, text_len, syn_num,
                      1).to(device).to(embd_p.dtype)
    w_h = torch.empty(batch_size, text_len, syn_num,
                      1).to(device).to(embd_p.dtype)
    #ww = torch.zeros(batch_size, text_len, syn_num, 1).to(device).to(embd.dtype)
    #ww = ww+500*(syn_valid.reshape(batch_size, text_len, syn_num, 1)-1)
    nn.init.kaiming_normal_(w_p)
    nn.init.kaiming_normal_(w_h)
    w_p.requires_grad_()
    w_h.requires_grad_()

    import utils
    params = [w_p, w_h]
    optimizer = utils.getOptimizer(params,
                                   name='adam',
                                   lr=w_optm_lr,
                                   weight_decay=2e-5)

    def get_comb_p(w, syn_valid):
        # Mask invalid synonym slots with -10000 before softmax.
        ww = w * syn_valid.reshape(
            batch_size, text_len, syn_num, 1) + 10000 * (
                syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
        return F.softmax(ww, -2)

    def get_comb_ww(w, syn_valid):
        # Same masking, raw logits (caller applies softmax).
        ww = w * syn_valid.reshape(
            batch_size, text_len, syn_num, 1) + 10000 * (
                syn_valid.reshape(batch_size, text_len, syn_num, 1) - 1)
        return ww

    def get_comb(p, syn):
        # Convex combination of (frozen) synonym embeddings.
        return (p * syn.detach()).sum(-2)

    embd_p_ori = embd_p.detach()
    embd_h_ori = embd_h.detach()
    logit_ori = self.embd_to_logit(embd_p_ori, embd_h_ori, x_p_mask,
                                   x_h_mask)

    # Ascent loop: losses are negated so the Adam *minimizer* maximizes them.
    for _ in range(num_steps):
        optimizer.zero_grad()
        with torch.enable_grad():
            ww_p = get_comb_ww(w_p, x_p_syn_valid)
            ww_h = get_comb_ww(w_h, x_h_syn_valid)
            #comb_p = get_comb_p(w, syn_valid)
            embd_p_adv = get_comb(F.softmax(ww_p, -2), x_p_syn_embd)
            embd_h_adv = get_comb(F.softmax(ww_h, -2), x_h_syn_embd)
            if attack_hypo_only:
                # Premise kept clean; only the hypothesis is perturbed.
                logit_adv = self.embd_to_logit(embd_p_ori, embd_h_adv,
                                               x_p_mask, x_h_mask)
            else:
                logit_adv = self.embd_to_logit(embd_p_adv, embd_h_adv,
                                               x_p_mask, x_h_mask)
            if loss_func == 'ce':
                loss = -F.cross_entropy(logit_adv, y, reduction='sum')
            elif loss_func == 'kl':
                criterion_kl = nn.KLDivLoss(reduction="sum")
                loss = -criterion_kl(F.log_softmax(logit_adv, dim=1),
                                     F.softmax(logit_ori.detach(), dim=1))
            #print("ad loss:", loss.data.item())

            if sparse_weight != 0:
                #loss_sparse = (comb_p*comb_p).mean()
                # Entropy penalty pushing weights toward one-hot; averaged
                # over p and h when both sides are attacked.
                if attack_hypo_only:
                    loss_sparse = (-F.softmax(ww_h, -2) *
                                   F.log_softmax(ww_h, -2)).sum(-2).mean()
                else:
                    loss_sparse = (
                        (-F.softmax(ww_p, -2) *
                         F.log_softmax(ww_p, -2)).sum(-2).mean() +
                        (-F.softmax(ww_h, -2) *
                         F.log_softmax(ww_h, -2)).sum(-2).mean()) / 2
                #loss -= sparse_weight*loss_sparse
                loss = loss + sparse_weight * loss_sparse
                #print(loss_sparse.data.item())

        #loss*=1000
        loss.backward()
        optimizer.step()
        #print((ww-w).max())

    x_p_comb_p = get_comb_p(w_p, x_p_syn_valid)
    x_h_comb_p = get_comb_p(w_h, x_h_syn_valid)
    # Dead debug snippet kept verbatim from the original (string statement).
    """ out = get_comb(comb_p, syn) delta = (out-embd_ori).reshape(batch_size*text_len,embd_dim) delta = F.pairwise_distance(delta, torch.zeros_like(delta), p=2.0) valid = (delta>0.01).to(device).to(delta.dtype) delta = (valid*delta).sum()/valid.sum() print("mean l2 dis between embd and embd_adv:", delta.data.item()) #print("mean max comb_p:", (comb_p.max(-2)[0]).mean().data.item()) """

    # resume context
    if self_training_context == True:
        self.train()
    else:
        self.eval()

    if out_type == "comb_p":
        return x_p_comb_p.detach(), x_h_comb_p.detach()
    elif out_type == "text":
        assert (x_p_text_like_syn is not None)  # n l synlen
        assert (x_h_text_like_syn is not None)  # n l synlen
        # Per position, pick the synonym id with the highest weight.
        x_p_comb_p = x_p_comb_p.reshape(batch_size * text_len, syn_num)
        x_h_comb_p = x_h_comb_p.reshape(batch_size * text_len, syn_num)
        ind_x_p = x_p_comb_p.max(-1)[1]  # shape batch_size* text_len
        ind_x_h = x_h_comb_p.max(-1)[1]  # shape batch_size* text_len
        adv_text_x_p = (x_p_text_like_syn.reshape(
            batch_size * text_len,
            syn_num)[np.arange(batch_size * text_len),
                     ind_x_p]).reshape(batch_size, text_len)
        adv_text_x_h = (x_h_text_like_syn.reshape(
            batch_size * text_len,
            syn_num)[np.arange(batch_size * text_len),
                     ind_x_h]).reshape(batch_size, text_len)
        return adv_text_x_p, adv_text_x_h