def eval_iters(ae_model, dis_model):
    eval_data_loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list, eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == eval_data_loader.num_batch

    add_log("Start eval process.")
    ae_model.eval()
    dis_model.eval()
    for it in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels)

        latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                       tensor_src_mask, tensor_tgt_mask)
        generator_text = ae_model.greedy_decode(
            latent, max_len=args.max_sequence_length, start_id=args.id_bos)
        print(id2text_sentence(generator_text[0], args.id_to_word))

        # Define target label
        target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
        if tensor_labels[0].item() > 0.5:
            target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))
        print("target_labels", target)

        modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                  args.max_sequence_length, args.id_bos,
                                  id2text_sentence, args.id_to_word,
                                  gold_ans[it])
        add_output(modify_text)
        output_text = str(it) + ":\ngold: " + id2text_sentence(
            gold_ans[it], args.id_to_word) + "\nmodified: " + modify_text
        add_output(output_text)
        add_result(
            str(it) + ":\n" + str(
                calc_bleu(id2text_sentence(gold_ans[it], args.id_to_word),
                          modify_text)))
    return

def fgim_step(model, origin_data, target, ae_model, max_sequence_length,
              id_bos, id2text_sentence, id_to_word, gold_ans, epsilon, step):
    """Fast Gradient Iterative Method: run `step` gradient updates on the latent."""
    dis_criterion = nn.BCELoss(reduction='mean')
    it = 0
    data = origin_data
    generator_id = None  # stays None if step == 0
    while it < step:
        data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
        # Set requires_grad attribute of the tensor. Important for the attack.
        data.requires_grad = True
        output = model.forward(data)
        loss = dis_criterion(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        data = data - epsilon * data_grad
        it += 1
        epsilon = epsilon * 0.9

        generator_id = ae_model.greedy_decode(data,
                                              max_len=max_sequence_length,
                                              start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("| It {:2d} | dis model pred {:5.4f} |".format(
            it, output[0].item()))
        print(generator_text)
    return data, generator_id

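# Usage sketch (illustrative, not part of the original script): perturb one
# latent code toward the opposite sentiment for a few decayed gradient steps.
# Assumes `args`, `ae_model`, `dis_model` and a batch from the data loader are
# set up as in eval_iters above.
#
#   latent, _ = ae_model.forward(tensor_src, tensor_tgt,
#                                tensor_src_mask, tensor_tgt_mask)
#   target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))  # flipped label
#   new_latent, new_ids = fgim_step(dis_model, latent, target, ae_model,
#                                   args.max_sequence_length, args.id_bos,
#                                   id2text_sentence, args.id_to_word,
#                                   gold_ans=None, epsilon=2.0, step=5)
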
def plot_tsne(ae_model, dis_model, epsilon=2, step=0):
    eval_data_loader = non_pair_data_loader(
        batch_size=500,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list, eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)

    ae_model.eval()
    dis_model.eval()
    latents, labels = [], []
    it = 0
    for _ in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels[0].item())

        latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                       tensor_src_mask, tensor_tgt_mask)
        # Define target label (flip the original one)
        target = get_cuda(
            torch.ones((tensor_labels.size(0), 1), dtype=torch.float))
        target = target - tensor_labels
        if step > 0:
            latent, modified_text = fgim_step(dis_model, latent, target,
                                              ae_model,
                                              args.max_sequence_length,
                                              args.id_bos, id2text_sentence,
                                              args.id_to_word, gold_ans[it],
                                              epsilon, step)
        latents.append(latent)
        labels.append(tensor_labels)
        it += tensor_labels.size(0)

    latents = torch.cat(latents, dim=0).detach().cpu().numpy()
    labels = torch.cat(labels, dim=0).squeeze().detach().cpu().numpy()
    tsne_plot_representation(latents, labels, f"tsne_step{step}_eps{epsilon}")

def plot_tsne(ae_model, dis_model):
    epsilon = 2
    step = 1
    eval_data_loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list, eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == eval_data_loader.num_batch

    ae_model.eval()
    dis_model.eval()
    latents, labels = [], []
    for it in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels.item())

        latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                       tensor_src_mask, tensor_tgt_mask)
        # Define target label
        target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
        if tensor_labels[0].item() > 0.5:
            target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

        modified_latent, modified_text = fgim_step(
            dis_model, latent, target, ae_model, args.max_sequence_length,
            args.id_bos, id2text_sentence, args.id_to_word, gold_ans[it],
            epsilon, step)
        latents.append(modified_latent)
        labels.append(tensor_labels.item())

    latents = torch.cat(latents, dim=0).detach().cpu().numpy()
    labels = numpy.array(labels)
    tsne_plot_representation(latents, labels)

def predict(args, ae_model, dis_model, batch, epsilon):
    (batch_sentences, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt,
     tensor_tgt_y, tensor_tgt_mask, tensor_ntokens) = batch
    ae_model.eval()
    dis_model.eval()

    latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                   tensor_tgt_mask)
    generator_text = ae_model.greedy_decode(latent,
                                            max_len=args.max_sequence_length,
                                            start_id=args.id_bos)
    print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
    print(id2text_sentence(generator_text[0], args.id_to_word))

    # Target label is the opposite of the original one
    target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
    if tensor_labels[0].item() > 0.5:
        target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

    dis_criterion = nn.BCELoss(reduction='mean')
    data = to_var(latent.clone())  # (batch_size, seq_length, latent_size)
    data.requires_grad = True
    output = dis_model.forward(data)
    loss = dis_criterion(output, target)
    dis_model.zero_grad()
    loss.backward()
    data_grad = data.grad.data
    data = data - epsilon * data_grad

    generator_id = ae_model.greedy_decode(data,
                                          max_len=args.max_sequence_length,
                                          start_id=args.id_bos)
    return id2text_sentence(generator_id[0], args.id_to_word)

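# Usage sketch (illustrative): run a single-step FGIM perturbation on one batch
# from a non_pair_data_loader and print the transferred sentence. Assumes the
# loader and models are built as in the surrounding functions.
#
#   batch = eval_data_loader.next_batch()
#   transferred = predict(args, ae_model, dis_model, batch, epsilon=2.0)
#   print(transferred)
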
def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
                tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
                tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()
            # For debug
            # print(batch_sentences[0])
            # print(tensor_src[0])
            # print(tensor_src_mask[0])
            # print("tensor_src_mask", tensor_src_mask.size())
            # print(tensor_tgt[0])
            # print(tensor_tgt_y[0])
            # print(tensor_tgt_mask[0])
            # print(batch_ntokens)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # print(latent.size())  # (batch_size, max_src_seq, d_model)
            # print(out.size())     # (batch_size, max_tgt_seq, vocab_size)

            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data
            # loss_all = loss_rec + loss_dis
            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier
            dis_lop = dis_model.forward(to_var(latent.clone()))
            loss_dis = dis_criterion(dis_lop, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch, loss_rec,
                            loss_dis))
                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return

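# Note (sketch, assuming the standard NoamOpt from "The Annotated Transformer"):
# the autoencoder learning rate above follows the warmup-then-decay schedule
#     lr = factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
# with factor=1 and warmup=2000, so the Adam lr=0 passed in is just a
# placeholder that NoamOpt overwrites at every optimizer step.
def noam_lr(step, d_model, factor=1, warmup=2000):
    """Reference implementation of the assumed schedule, for inspection only."""
    step = max(step, 1)
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)
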
def fgim(data_loader, args, ae_model, c_theta, gold_ans=None):
    """
    Input:
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
        Well-trained attribute classifier C_θ
        Target attribute y
        A set of weights w = {w_i}
        Decay coefficient λ
        Threshold t
    Output:
        An optimal modified latent representation z'
    """
    w = args.w
    lambda_ = args.lambda_
    t = args.threshold
    max_iter_per_epsilon = args.max_iter_per_epsilon
    max_sequence_length = args.max_sequence_length
    id_bos = args.id_bos
    id_to_word = args.id_to_word
    limit_batches = args.limit_batches

    text_z_prime = {
        "source": [],
        "origin_labels": [],
        "before": [],
        "after": [],
        "change": [],
        "pred_label": []
    }
    if gold_ans is not None:
        text_z_prime["gold_ans"] = []
    z_prime = []
    dis_criterion = nn.BCELoss(reduction='mean')
    n_batches = 0
    for it in tqdm.tqdm(list(range(data_loader.num_batch)), desc="FGIM"):
        if gold_ans is not None:
            text_z_prime["gold_ans"].append(gold_ans[it])

        _, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, _ = data_loader.next_batch()
        # keep only negative examples
        negative_examples = ~(tensor_labels.squeeze() == args.positive_label)
        tensor_labels = tensor_labels[negative_examples].squeeze(0)  # .view(1, -1)
        tensor_src = tensor_src[negative_examples].squeeze(0)
        tensor_src_attn_mask = tensor_src_attn_mask[negative_examples].squeeze(0)
        tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
        tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
        tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
        tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)
        # if gold_ans is not None:
        #     text_z_prime["gold_ans"][-1] = text_z_prime["gold_ans"][-1][negative_examples]
        # print("------------%d------------" % it)
        if negative_examples.any():
            text_z_prime["source"].append(
                [id2text_sentence(sent, args.id_to_word) for sent in tensor_tgt_y])
            text_z_prime["origin_labels"].append(tensor_labels.cpu().numpy())

            origin_data, _ = ae_model.forward(tensor_src, tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask)
            # Define target label
            y_prime = 1.0 - (tensor_labels > 0.5).float()

            ############################### FGIM ###############################
            generator_id = ae_model.greedy_decode(origin_data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            generator_text = [
                id2text_sentence(gid, id_to_word) for gid in generator_id
            ]
            text_z_prime["before"].append(generator_text)

            flag = False
            for w_i in w:
                # print("---------- w_i:", w_i)
                data = to_var(origin_data.clone())  # (batch_size, seq_length, latent_size)
                b = True
                if b:
                    data.requires_grad = True
                    output = c_theta.forward(data)
                    loss = dis_criterion(output, y_prime)
                    c_theta.zero_grad()
                    loss.backward()
                    data = data - w_i * data.grad.data
                else:
                    data = origin_data
                    output = c_theta.forward(data)

                it = 0
                while True:
                    # if torch.cdist(output, y_prime) < t:
                    # if torch.sum((output - y_prime) ** 2, dim=1).sqrt().mean() < t:
                    if torch.sum((output - y_prime).abs(), dim=1).mean() < t:
                        flag = True
                        break
                    data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
                    # Set requires_grad attribute of the tensor. Important for the attack.
                    data.requires_grad = True
                    output = c_theta.forward(data)
                    # Calculate gradients of the model in the backward pass
                    loss = dis_criterion(output, y_prime)
                    c_theta.zero_grad()
                    # dis_optimizer.zero_grad()
                    loss.backward()
                    data = data - w_i * data.grad.data
                    it += 1
                    # data = perturbed_data
                    w_i = lambda_ * w_i
                    if False:
                        if text_gen_params is not None:
                            generator_id = ae_model.greedy_decode(
                                data,
                                max_len=max_sequence_length,
                                start_id=id_bos)
                            generator_text = id2text_sentence(generator_id[0], id_to_word)
                            print("| It {:2d} | dis model pred {:5.4f} |".format(
                                it, output[0].item()))
                            print(generator_text)
                    if it > max_iter_per_epsilon:
                        break

                if flag:
                    z_prime.append(data)
                    generator_id = ae_model.greedy_decode(data,
                                                          max_len=max_sequence_length,
                                                          start_id=id_bos)
                    generator_text = [
                        id2text_sentence(gid, id_to_word) for gid in generator_id
                    ]
                    text_z_prime["after"].append(generator_text)
                    text_z_prime["change"].append([True] * len(output))
                    text_z_prime["pred_label"].append([o.item() for o in output])
                    break

            if not flag:  # could not debias this batch
                z_prime.append(origin_data)
                text_z_prime["after"].append(text_z_prime["before"][-1])
                text_z_prime["change"].append([False] * len(y_prime))
                text_z_prime["pred_label"].append([o.item() for o in y_prime])

            n_batches += 1
            if n_batches > limit_batches:
                break
    return z_prime, text_z_prime

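# Core update used by fgim above, in isolation (sketch for reference; reuses the
# repo's to_var helper and torch.nn as imported elsewhere in this file). Starting
# from z, FGIM takes steps z <- z - w_i * grad_z BCE(c_theta(z), y_prime), decays
# w_i by lambda_ each iteration, and stops once |c_theta(z) - y_prime| < t.
def fgim_update(z, c_theta, y_prime, w_i):
    """One FGIM gradient step on a detached copy of the latent code."""
    criterion = nn.BCELoss(reduction='mean')
    z = to_var(z.clone())
    z.requires_grad = True
    loss = criterion(c_theta(z), y_prime)
    c_theta.zero_grad()
    loss.backward()
    return (z - w_i * z.grad.data).detach()
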
def fgim_attack(model, origin_data, target, ae_model, max_sequence_length,
                id_bos, id2text_sentence, id_to_word, gold_ans=None):
    """Fast Gradient Iterative Method"""
    dis_criterion = nn.BCELoss(reduction='mean')
    t = 0.001  # Threshold
    lambda_ = 0.9  # Decay coefficient
    max_iter_per_epsilon = 20
    if gold_ans is not None:
        gold_text = id2text_sentence(gold_ans, id_to_word)
        print("gold:", gold_text)

    flag = False
    for epsilon in [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]:
        print("---------- epsilon:", epsilon)
        generator_id = ae_model.greedy_decode(origin_data,
                                              max_len=max_sequence_length,
                                              start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("z:", generator_text)

        data = to_var(origin_data.clone())  # (batch_size, seq_length, latent_size)
        b = True
        if b:
            data.requires_grad = True
            output = model.forward(data)
            loss = dis_criterion(output, target)
            model.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            data = data - epsilon * data_grad
        else:
            data = origin_data

        it = 0
        while True:
            if torch.cdist(output, target) < t:
                flag = True
                break
            data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
            # Set requires_grad attribute of the tensor. Important for the attack.
            data.requires_grad = True
            output = model.forward(data)
            # Calculate gradients of the model in the backward pass
            loss = dis_criterion(output, target)
            model.zero_grad()
            # dis_optimizer.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            data = data - epsilon * data_grad
            it += 1
            # data = perturbed_data
            epsilon = epsilon * lambda_
            if False:
                generator_id = ae_model.greedy_decode(data,
                                                      max_len=max_sequence_length,
                                                      start_id=id_bos)
                generator_text = id2text_sentence(generator_id[0], id_to_word)
                print("| It {:2d} | dis model pred {:5.4f} |".format(
                    it, output[0].item()))
                print(generator_text)
            if it > max_iter_per_epsilon:
                break

        generator_id = ae_model.greedy_decode(data,
                                              max_len=max_sequence_length,
                                              start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("|dis model pred {:5.4f} |".format(output[0].item()))
        print("z*", generator_text)
        print()
        if flag:
            return generator_text
    # Fall back to the last decoded sentence if the threshold was never reached,
    # so callers always receive a string.
    return generator_text

def eval_iters(ae_model, dis_model):
    # tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME, do_lower_case=True)
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))
    print("[CLS] ID: ", bos_id)

    # if args.task == 'news_china_taiwan':
    eval_file_list = [
        args.data_path + 'test.0',
        args.data_path + 'test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    if args.eval_positive:
        eval_file_list = eval_file_list[::-1]
        eval_label_list = eval_label_list[::-1]

    print("Load testData...")
    testData = TextDataset(batch_size=args.batch_size,
                           id_bos='[CLS]',
                           id_eos='[EOS]',
                           id_unk='[UNK]',
                           max_sequence_length=args.max_sequence_length,
                           vocab_size=0,
                           file_list=eval_file_list,
                           label_list=eval_label_list,
                           tokenizer=tokenizer)
    dataset = testData
    eval_data_loader = DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=dataset.collate_fn,
                                  num_workers=4)
    num_batch = len(eval_data_loader)
    trange = tqdm(enumerate(eval_data_loader),
                  total=num_batch,
                  desc='Evaluating',
                  file=sys.stdout,
                  position=0,
                  leave=True)
    gold_ans = [''] * num_batch

    add_log("Start eval process.")
    ae_model.to(device)
    dis_model.to(device)
    ae_model.eval()
    dis_model.eval()
    total_latent_lst = []
    for it, data in trange:
        batch_sentences, tensor_labels, tensor_src, tensor_src_mask, \
            tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data
        tensor_labels = tensor_labels.to(device)
        tensor_src = tensor_src.to(device)
        tensor_tgt = tensor_tgt.to(device)
        tensor_tgt_y = tensor_tgt_y.to(device)
        tensor_src_mask = tensor_src_mask.to(device)
        tensor_tgt_mask = tensor_tgt_mask.to(device)

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
        print("origin_labels", tensor_labels.cpu().detach().numpy()[0])

        latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                       tensor_src_mask, tensor_tgt_mask)
        generator_text = ae_model.greedy_decode(
            latent, max_len=args.max_sequence_length, start_id=bos_id)
        print(id2text_sentence(generator_text[0], tokenizer, args.task))

        # Define target label
        target = torch.FloatTensor([[1.0]]).to(device)
        if tensor_labels[0].item() > 0.5:
            target = torch.FloatTensor([[0.0]]).to(device)
        print("target_labels", target)

        modify_text, latent_lst = fgim_attack(dis_model,
                                              latent,
                                              target,
                                              ae_model,
                                              args.max_sequence_length,
                                              bos_id,
                                              id2text_sentence,
                                              None,
                                              gold_ans[it],
                                              tokenizer,
                                              device,
                                              task=args.task,
                                              save_latent=args.save_latent)
        if args.save_latent != -1:
            total_latent_lst.append(latent_lst)
        add_output(modify_text)
        if it >= args.save_latent_num:
            break

    print("Save log in ", args.output_file)
    if args.save_latent == -1:
        return
    folder = './latent_{}/'.format(args.task)
    if not os.path.exists(folder):
        os.mkdir(folder)
    if args.save_latent == 0:    # full
        prefix = 'full'
    elif args.save_latent == 1:  # first 6 layers
        prefix = 'first_6'
    elif args.save_latent == 2:  # last 6 layers
        prefix = 'last_6'
    elif args.save_latent == 3:  # second layer only
        prefix = 'distill_2'
    total_latent_lst = np.asarray(total_latent_lst)
    if args.eval_negative:
        save_label = 0
    else:
        save_label = 1
    with open(folder + '{}_{}.pkl'.format(prefix, save_label), 'wb') as f:
        pickle.dump(total_latent_lst, f)
    print("Save latent in ", folder + '{}_{}.pkl'.format(prefix, save_label))

def train_iters(ae_model, dis_model):
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    # tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME, do_lower_case=True)
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))
    # print("[CLS] ID: ", bos_id)

    print("Load trainData...")
    if args.load_trainData and os.path.exists('./{}_trainData.pkl'.format(args.task)):
        with open('./{}_trainData.pkl'.format(args.task), 'rb') as f:
            trainData = pickle.load(f)
    else:
        trainData = TextDataset(batch_size=args.batch_size,
                                id_bos='[CLS]',
                                id_eos='[EOS]',
                                id_unk='[UNK]',
                                max_sequence_length=args.max_sequence_length,
                                vocab_size=0,
                                file_list=args.train_file_list,
                                label_list=args.train_label_list,
                                tokenizer=tokenizer)
        with open('./{}_trainData.pkl'.format(args.task), 'wb') as f:
            pickle.dump(trainData, f)

    add_log("Start train process.")
    ae_model.train()
    dis_model.train()
    ae_model.to(device)
    dis_model.to(device)

    # Fixing or distilling BERT encoder layers
    if args.fix_first_6:
        print("Try fixing first 6 BERT layers")
        for layer in range(6):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False
    elif args.fix_last_6:
        print("Try fixing last 6 BERT layers")
        for layer in range(6, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False
    if args.distill_2:
        print("Get result from layer 2")
        for layer in range(2, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False

    ae_optimizer = NoamOpt(
        ae_model.d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)
    # ae_criterion = get_cuda(LabelSmoothing(size=args.vocab_size, padding_idx=args.id_pad, smoothing=0.1))
    ae_criterion = LabelSmoothing(size=ae_model.bert_encoder.config.vocab_size,
                                  padding_idx=0,
                                  smoothing=0.1).to(device)
    dis_criterion = nn.BCELoss(reduction='mean')

    history = {'train': []}
    for epoch in range(args.epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        total_rec_loss = 0
        total_dis_loss = 0
        train_data_loader = DataLoader(trainData,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       collate_fn=trainData.collate_fn,
                                       num_workers=4)
        num_batch = len(train_data_loader)
        trange = tqdm(enumerate(train_data_loader),
                      total=num_batch,
                      desc='Training',
                      file=sys.stdout,
                      position=0,
                      leave=True)
        for it, data in trange:
            batch_sentences, tensor_labels, tensor_src, tensor_src_mask, \
                tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data
            tensor_labels = tensor_labels.to(device)
            tensor_src = tensor_src.to(device)
            tensor_tgt = tensor_tgt.to(device)
            tensor_tgt_y = tensor_tgt_y.to(device)
            tensor_src_mask = tensor_src_mask.to(device)
            tensor_tgt_mask = tensor_tgt_mask.to(device)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data
            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            latent = latent.detach()
            next_latent = latent.to(device)

            # Classifier
            dis_lop = dis_model.forward(next_latent)
            loss_dis = dis_criterion(dis_lop, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            total_rec_loss += loss_rec.item()
            total_dis_loss += loss_dis.item()
            trange.set_postfix(total_rec_loss=total_rec_loss / (it + 1),
                               total_dis_loss=total_dis_loss / (it + 1))

            if it % 100 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, num_batch, loss_rec, loss_dis))
                print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=bos_id)
                print(id2text_sentence(generator_text[0], tokenizer, args.task))

        # Save model
        # torch.save(ae_model.state_dict(), args.current_save_path / 'ae_model_params.pkl')
        # torch.save(dis_model.state_dict(), args.current_save_path / 'dis_model_params.pkl')
        history['train'].append({
            'epoch': epoch,
            'total_rec_loss': total_rec_loss / len(trange),
            'total_dis_loss': total_dis_loss / len(trange)
        })
        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))

    # Save model
    torch.save(ae_model.state_dict(),
               args.current_save_path / 'ae_model_params.pkl')
    torch.save(dis_model.state_dict(),
               args.current_save_path / 'dis_model_params.pkl')
    print("Save in ", args.current_save_path)
    return

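# Helper sketch (not in the original code): the layer-freezing logic above,
# factored out. Assumes a HuggingFace BERT-style encoder whose transformer
# blocks live in `bert_encoder.encoder.layer`.
def freeze_bert_layers(bert_encoder, layer_indices):
    """Disable gradients for the given transformer blocks."""
    for layer in layer_indices:
        for param in bert_encoder.encoder.layer[layer].parameters():
            param.requires_grad = False

# e.g. freeze_bert_layers(ae_model.bert_encoder, range(6))      # fix first 6
#      freeze_bert_layers(ae_model.bert_encoder, range(6, 12))  # fix last 6
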
def sedat_eval(args, ae_model, f, deb):
    """
    Input:
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
    Output:
        An optimal modified latent representation z'
    """
    max_sequence_length = args.max_sequence_length
    id_bos = args.id_bos
    id_to_word = args.id_to_word
    limit_batches = args.limit_batches

    eval_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    eval_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == eval_data_loader.num_batch
    else:
        gold_ans = None

    add_log(args, "Start eval process.")
    ae_model.eval()
    f.eval()
    deb.eval()

    text_z_prime = {
        "source": [],
        "origin_labels": [],
        "before": [],
        "after": [],
        "change": [],
        "pred_label": []
    }
    if gold_ans is not None:
        text_z_prime["gold_ans"] = []
    z_prime = []
    n_batches = 0
    for it in tqdm.tqdm(list(range(eval_data_loader.num_batch)), desc="SEDAT"):
        _, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, _ = eval_data_loader.next_batch()

        # keep only negative examples
        negative_examples = ~(tensor_labels.squeeze() == args.positive_label)
        tensor_labels = tensor_labels[negative_examples].squeeze(0)  # .view(1, -1)
        tensor_src = tensor_src[negative_examples].squeeze(0)
        tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
        tensor_src_attn_mask = tensor_src_attn_mask[negative_examples].squeeze(0)
        tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
        tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
        tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)

        if negative_examples.any():
            if gold_ans is not None:
                text_z_prime["gold_ans"].append(gold_ans[it])
            text_z_prime["source"].append(
                [id2text_sentence(sent, args.id_to_word) for sent in tensor_tgt_y])
            text_z_prime["origin_labels"].append(tensor_labels.cpu().numpy())

            origin_data, _ = ae_model.forward(tensor_src, tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask)
            generator_id = ae_model.greedy_decode(origin_data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            generator_text = [
                id2text_sentence(gid, id_to_word) for gid in generator_id
            ]
            text_z_prime["before"].append(generator_text)

            data = deb(origin_data, mask=None)
            data = torch.sum(ae_model.sigmoid(data), dim=1)  # (batch_size, d_model)
            # logit = ae_model.decode(data.unsqueeze(1), tensor_tgt, tensor_tgt_mask)  # (batch_size, max_tgt_seq, d_model)
            # output = ae_model.generator(logit)  # (batch_size, max_seq, vocab_size)
            y_hat = f.forward(data)
            y_hat = y_hat.round().int()
            z_prime.append(data)

            generator_id = ae_model.greedy_decode(data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            generator_text = [
                id2text_sentence(gid, id_to_word) for gid in generator_id
            ]
            text_z_prime["after"].append(generator_text)
            text_z_prime["change"].append([True] * len(y_hat))
            text_z_prime["pred_label"].append([y_.item() for y_ in y_hat])

            n_batches += 1
            if n_batches > limit_batches:
                break

    write_text_z_in_file(args, text_z_prime)
    add_log(args, "")
    add_log(args, "Saving modified embeddings to %s ..." % args.current_save_path)
    torch.save(z_prime,
               os.path.join(args.current_save_path, 'z_prime_sedat.pkl'))
    return z_prime, text_z_prime

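# The debiasing step used above, in isolation (sketch; assumes `deb` maps a
# (batch, seq_len, d_model) latent to a tensor of the same shape and that the
# autoencoder exposes the `sigmoid` module used in sedat_eval):
#
#   z, _ = ae_model.forward(src, tgt, src_mask, src_attn_mask, tgt_mask)
#   z_prime = torch.sum(ae_model.sigmoid(deb(z, mask=None)), dim=1)  # (batch, d_model)
#   sentences = ae_model.greedy_decode(z_prime,
#                                      max_len=args.max_sequence_length,
#                                      start_id=args.id_bos)
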
def train_step(args, data_loader, ae_model, dis_model, ae_optimizer,
               dis_optimizer, ae_criterion, dis_criterion, epoch):
    ae_model.train()
    dis_model.train()

    loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
    loss_clf, total_clf, n_valid_clf = 0, 0, 0
    epoch_start_time = time.time()
    for it in range(data_loader.num_batch):
        flag_rec = True
        batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = data_loader.next_batch()

        # Forward pass
        latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                       tensor_src_mask, tensor_src_attn_mask,
                                       tensor_tgt_mask)
        # Loss calculation
        if not args.sedat:
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / (tensor_ntokens.data + eps)
        else:
            # only on positive examples
            positive_examples = tensor_labels.squeeze() == args.positive_label
            out = out[positive_examples]  # or out[positive_examples, :, :]
            tensor_tgt_y = tensor_tgt_y[positive_examples]  # or tensor_tgt_y[positive_examples, :]
            tensor_ntokens = (tensor_tgt_y != 0).data.sum().float()
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / (tensor_ntokens.data + eps)
            flag_rec = positive_examples.any()
            out = out.squeeze(0)
            tensor_tgt_y = tensor_tgt_y.squeeze(0)

        if flag_rec:
            n_v, n_w = get_n_v_w(tensor_tgt_y, out)
        else:
            n_w = float("nan")
            n_v = float("nan")

        x_e = loss_rec.item() * n_w
        loss_ae += loss_rec.item()
        n_words_ae += n_w
        xe_loss_ae += x_e
        n_valid_ae += n_v
        ae_acc = 100. * n_v / (n_w + eps)
        avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
        avg_ae_loss = loss_ae / (it + 1)
        ae_ppl = np.exp(x_e / (n_w + eps))
        avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

        ae_optimizer.zero_grad()
        loss_rec.backward(retain_graph=not args.detach_classif)
        ae_optimizer.step()

        # Classifier
        if args.detach_classif:
            dis_lop = dis_model.forward(to_var(latent.clone()))
        else:
            dis_lop = dis_model.forward(latent)
        loss_dis = dis_criterion(dis_lop, tensor_labels)
        dis_optimizer.zero_grad()
        loss_dis.backward()
        dis_optimizer.step()

        t_c = tensor_labels.view(-1).size(0)
        n_v = (dis_lop.round().int() == tensor_labels).sum().item()
        loss_clf += loss_dis.item()
        total_clf += t_c
        n_valid_clf += n_v
        clf_acc = 100. * n_v / (t_c + eps)
        avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
        avg_clf_loss = loss_clf / (it + 1)

        if it % args.log_interval == 0:
            add_log(
                args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                    epoch, it, data_loader.num_batch))
            add_log(
                args,
                'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                        loss_dis.item()))
            add_log(
                args,
                'Train, avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl, avg_clf_acc,
                        avg_clf_loss))
            if flag_rec:
                i = random.randint(0, len(tensor_tgt_y) - 1)
                reference = id2text_sentence(tensor_tgt_y[i], args.id_to_word)
                add_log(args, "input : %s" % reference)
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                # batch_sentences
                hypothesis = id2text_sentence(generator_text[i], args.id_to_word)
                add_log(args, "gen : %s" % hypothesis)
                add_log(
                    args, "bleu : %s" %
                    calc_bleu(reference.split(" "), hypothesis.split(" ")))

    s = {}
    L = data_loader.num_batch + eps
    s["train_ae_loss"] = loss_ae / L
    s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
    s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
    s["train_clf_loss"] = loss_clf / L
    s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)
    add_log(
        args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
    add_log(
        args,
        '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
        .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                s["train_clf_acc"], s["train_clf_loss"]))
    return s

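# Metric note (sketch): the reconstruction perplexity logged above is the
# exponential of the average per-token cross-entropy, accumulated over the
# epoch as xe_loss_ae / n_words_ae.
def perplexity(total_xe_loss, total_words, smooth=1e-12):
    """Corpus-level perplexity from summed cross-entropy and token count."""
    return np.exp(total_xe_loss / (total_words + smooth))
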
def sedat_train(args, ae_model, f, deb):
    """
    Input:
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
    Output:
        An optimal modified latent representation z'
    """
    # TODO: find a metric to control the evolution of training, mainly for the deb model
    lambda_ = args.sedat_threshold
    alpha, beta = [float(coef) for coef in args.sedat_alpha_beta.split(",")]
    # only on negative examples
    only_on_negative_example = args.sedat_only_on_negative_example
    penalty = args.penalty
    type_penalty = args.type_penalty
    assert penalty in ["lasso", "ridge"]
    assert type_penalty in ["last", "group"]

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.train_data_file]
    if os.path.exists(args.val_data_file):
        file_list.append(args.val_data_file)
    train_data_loader.create_batches(args,
                                     file_list,
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    add_log(args, "Start train process.")
    ae_model.train()
    f.train()
    deb.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=f.parameters(),
                                  s=args.dis_optimizer)
    deb_optimizer = get_optimizer(parameters=deb.parameters(),
                                  s=args.dis_optimizer)
    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    dis_criterion = nn.BCELoss(reduction='mean')
    deb_criterion = LossSedat(penalty=penalty)

    stats = []
    for epoch in range(args.max_epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
        loss_clf, total_clf, n_valid_clf = 0, 0, 0
        for it in range(train_data_loader.num_batch):
            _, tensor_labels, \
                tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
                tensor_tgt_mask, _ = train_data_loader.next_batch()

            flag = True
            # only on negative examples
            if only_on_negative_example:
                negative_examples = ~(tensor_labels.squeeze() == args.positive_label)
                tensor_labels = tensor_labels[negative_examples].squeeze(0)  # .view(1, -1)
                tensor_src = tensor_src[negative_examples].squeeze(0)
                tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
                tensor_src_attn_mask = tensor_src_attn_mask[negative_examples].squeeze(0)
                tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
                tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
                tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)
                flag = negative_examples.any()
            if not flag:
                continue

            # Forward pass
            z, out, z_list = ae_model.forward(tensor_src,
                                              tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask,
                                              return_intermediate=True)

            # Classifier f
            # y_hat = f.forward(to_var(z.clone()))
            y_hat = f.forward(z)
            loss_dis = dis_criterion(y_hat, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward(retain_graph=True)
            dis_optimizer.step()

            dis_lop = f.forward(z)
            t_c = tensor_labels.view(-1).size(0)
            n_v = (dis_lop.round().int() == tensor_labels).sum().item()
            loss_clf += loss_dis.item()
            total_clf += t_c
            n_valid_clf += n_v
            clf_acc = 100. * n_v / (t_c + eps)
            avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
            avg_clf_loss = loss_clf / (it + 1)

            mask_deb = (y_hat.squeeze() >= lambda_
                        if args.positive_label == 0
                        else y_hat.squeeze() < lambda_)
            # if f(z) > lambda :
            if mask_deb.any():
                y_hat_deb = y_hat[mask_deb]
                if type_penalty == "last":
                    z_deb = z[mask_deb].squeeze(0) if args.batch_size == 1 else z[mask_deb]
                elif type_penalty == "group":
                    # TODO: unit test for batch_size == 1
                    z_deb = z_list[-1][mask_deb]
                z_prime, z_prime_list = deb(z_deb, mask=None, return_intermediate=True)
                if type_penalty == "last":
                    z_prime = torch.sum(ae_model.sigmoid(z_prime), dim=1)
                    loss_deb = alpha * deb_criterion(z_deb, z_prime, is_list=False) \
                               + beta * y_hat_deb.sum()
                elif type_penalty == "group":
                    z_deb_list = [z_[mask_deb] for z_ in z_list]
                    # assert len(z_deb_list) == len(z_prime_list)
                    loss_deb = alpha * deb_criterion(z_deb_list, z_prime_list, is_list=True) \
                               + beta * y_hat_deb.sum()
                deb_optimizer.zero_grad()
                loss_deb.backward(retain_graph=True)
                deb_optimizer.step()
            else:
                loss_deb = torch.tensor(float("nan"))

            # else :
            if (~mask_deb).any():
                out_ = out[~mask_deb]
                tensor_tgt_y_ = tensor_tgt_y[~mask_deb]
                tensor_ntokens = (tensor_tgt_y_ != 0).data.sum().float()
                loss_rec = ae_criterion(
                    out_.contiguous().view(-1, out_.size(-1)),
                    tensor_tgt_y_.contiguous().view(-1)) / (tensor_ntokens.data + eps)
            else:
                loss_rec = torch.tensor(float("nan"))

            ae_optimizer.zero_grad()
            (loss_dis + loss_deb + loss_rec).backward()
            ae_optimizer.step()

            if True:
                n_v, n_w = get_n_v_w(tensor_tgt_y, out)
            else:
                n_w = float("nan")
                n_v = float("nan")
            x_e = loss_rec.item() * n_w
            loss_ae += loss_rec.item()
            n_words_ae += n_w
            xe_loss_ae += x_e
            n_valid_ae += n_v
            ae_acc = 100. * n_v / (n_w + eps)
            avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
            avg_ae_loss = loss_ae / (it + 1)
            ae_ppl = np.exp(x_e / (n_w + eps))
            avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

            if it % args.log_interval == 0:
                add_log(args, "")
                add_log(
                    args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                        epoch, it, train_data_loader.num_batch))
                add_log(
                    args,
                    'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                    .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                            loss_dis.item()))
                add_log(
                    args,
                    'Train, avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                    .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl, avg_clf_acc,
                            avg_clf_loss))
                add_log(
                    args, "input : %s" %
                    id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    z, max_len=args.max_sequence_length, start_id=args.id_bos)
                # batch_sentences
                add_log(
                    args, "gen : %s" %
                    id2text_sentence(generator_text[0], args.id_to_word))
                if mask_deb.any():
                    generator_text_prime = ae_model.greedy_decode(
                        z_prime,
                        max_len=args.max_sequence_length,
                        start_id=args.id_bos)
                    add_log(
                        args, "deb : %s" %
                        id2text_sentence(generator_text_prime[0], args.id_to_word))

        s = {}
        L = train_data_loader.num_batch + eps
        s["train_ae_loss"] = loss_ae / L
        s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
        s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
        s["train_clf_loss"] = loss_clf / L
        s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)
        stats.append(s)

        add_log(args, "")
        add_log(
            args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
                epoch, (time.time() - epoch_start_time)))
        add_log(
            args,
            '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                    s["train_clf_acc"], s["train_clf_loss"]))

        # Save models
        torch.save(
            ae_model.state_dict(),
            os.path.join(args.current_save_path, 'ae_model_params_deb.pkl'))
        torch.save(
            f.state_dict(),
            os.path.join(args.current_save_path, 'dis_model_params_deb.pkl'))
        torch.save(
            deb.state_dict(),
            os.path.join(args.current_save_path, 'deb_model_params_deb.pkl'))
        add_log(args, "Saving training statistics %s ..." % args.current_save_path)
        torch.save(stats, os.path.join(args.current_save_path, 'stats_train_deb.pkl'))

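# Objective used for the debiaser above (summary, not additional code): for the
# latents flagged as biased (y_hat >= lambda_ when positive_label == 0), the
# update minimises
#     alpha * deb_criterion(z, deb(z)) + beta * f(deb(z))
# i.e. a lasso/ridge penalty keeping z' close to z (LossSedat) plus the
# classifier score of the debiased latent, while unflagged examples only
# contribute the reconstruction loss.
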
def fgim_algorithm(args, ae_model, dis_model):
    batch_size = 1
    test_data_loader = non_pair_data_loader(
        batch_size=batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    test_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == test_data_loader.num_batch
    else:
        gold_ans = [[None] * batch_size] * test_data_loader.num_batch

    add_log(args, "Start eval process.")
    ae_model.eval()
    dis_model.eval()

    fgim_our = True
    if fgim_our:  # batched FGIM
        z_prime, text_z_prime = fgim(test_data_loader,
                                     args,
                                     ae_model,
                                     dis_model,
                                     gold_ans=gold_ans)
        write_text_z_in_file(args, text_z_prime)
        add_log(args,
                "Saving modified embeddings to %s ..." % args.current_save_path)
        torch.save(z_prime,
                   os.path.join(args.current_save_path, 'z_prime_fgim.pkl'))
    else:  # original per-sentence FGIM attack
        for it in range(test_data_loader.num_batch):
            batch_sentences, tensor_labels, \
                tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
                tensor_tgt_mask, tensor_ntokens = test_data_loader.next_batch()

            print("------------%d------------" % it)
            print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
            print("origin_labels", tensor_labels)

            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask,
                                           tensor_src_attn_mask,
                                           tensor_tgt_mask)
            generator_text = ae_model.greedy_decode(
                latent, max_len=args.max_sequence_length, start_id=args.id_bos)
            print(id2text_sentence(generator_text[0], args.id_to_word))

            # Define target label
            target = get_cuda(torch.tensor([[1.0]], dtype=torch.float), args)
            if tensor_labels[0].item() > 0.5:
                target = get_cuda(torch.tensor([[0.0]], dtype=torch.float), args)
            add_log(args, "target_labels : %s" % target)

            modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                      args.max_sequence_length, args.id_bos,
                                      id2text_sentence, args.id_to_word,
                                      gold_ans[it])
            add_output(args, modify_text)

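# Entry-point sketch (assumption): with trained autoencoder and classifier
# checkpoints loaded, the whole FGIM evaluation is driven by
#
#   fgim_algorithm(args, ae_model, dis_model)
#
# which writes the transferred sentences via write_text_z_in_file and stores
# the modified latents in z_prime_fgim.pkl under args.current_save_path.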