def mle_train(batches, opt, do_train=True, do_log=True, do_adv=False, adv_config=None):
    # Standard MLE training pass; optionally interleaves an adversarial
    # (unlearning) update every ADV_RATIO minibatches.
    for m in m_dict:
        if do_train:
            m_dict[m].train()
        else:
            m_dict[m].eval()
    if do_adv:
        inf_adv_batches = adv_config['inf_adv_batches']
    all_loss, all_num, adv_all_loss, adv_all_num, nodecay_sennum = 0, 0, 0, 0, 0
    b_count, adv_b_count = 0, 0
    loss_sen = []
    for src_mb, tgt_mb, tgt_len, src_w, tgt_w in batches:
        # print(src_w[0], src_w[1]); sys.exit(1)
        b_count = b_count + 1
        bz = src_mb.size(0)
        all_num = all_num + sum(tgt_len)
        batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, adv_config)
        if False:  # ===drawing debug (disabled)===
            logprob_lis = []
            w_lis = []
            for i in range(bz // 2):
                ss = ""
                for j in range(tgt_len[i]):
                    ss = ss + tgt_w[i][j + 1] + '(' + '%.2f' % batch_logpdf[i][j].item() + ') '
                    logprob_lis.append(batch_logpdf[i][j].item())
                    if batch_logpdf[i][j] < -1000:
                        print(tgt_w[i][j + 1], batch_logpdf[i][j])
                    w_lis.append(tgt_w[i][j + 1])
                # print(ss)
            # torch.save(logprob_lis, 'figs/advtrain_wordlogps/naivetrainR' + str(ADV_RATIO) + '_' + SUC + '.data')
            # torch.save(w_lis, 'figs/advtrain_wordlogps/naivetrainR' + str(ADV_RATIO) + '_' + SUC + '_w.data')
            # sys.exit(1)
            # ===end===
        # print(torch.min(batch_logpdf))
        w_loss_rnn = torch.sum(batch_logpdf)
        loss_sen.extend(torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())
        all_loss = all_loss + w_loss_rnn.data.item()
        if do_train:
            for m in m_dict.values():
                m.zero_grad()
            # minimize the average negative log-likelihood per token
            (-w_loss_rnn / sum(tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt.step()
        if do_train and do_adv and b_count % ADV_RATIO == 0:
            target_mb = next(inf_adv_batches)
            src_mb, tgt_mb, tgt_len, src_lis, tgt_lis = get_adv_seq2seq_mb(
                target_mb, ADV_ATTACK, m_dict, adv_config)
            # print(tgt_lis[0], tgt_mb[0], tgt_len[0])
            for m in m_dict:
                # get_adv_seq2seq_mb could change the flags
                m_dict[m].zero_grad()
                m_dict[m].train()
            adv_batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, adv_config)
            decay_co = get_decay_co(adv_batch_logpdf, tgt_len, adv_config)
            nodecay_sennum += np.sum(decay_co.cpu().numpy() == 1)
            adv_b_count += 1
            # push down the likelihood of adversarial targets, weighted per sentence
            sen_logpdf = torch.sum(adv_batch_logpdf, dim=1) * decay_co
            (ADV_LAMBDA * torch.sum(sen_logpdf) / sum(tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt.step()
            adv_all_loss += torch.sum(adv_batch_logpdf).item()
            adv_all_num += sum(tgt_len)
        if do_log and b_count % LOG_INTERVAL == 0:
            logger.info('avg loss at b: %d , %f', b_count, all_loss * 1.0 / all_num)
            logger.info('all_num: %d', all_num)
            logger.info('sen_avg_loss: %f', np.mean(loss_sen))
            if do_adv:
                logger.info('adv_all_num: %d avg_nodecay_sennum: %f', adv_all_num,
                            float(nodecay_sennum * 1.0) / adv_b_count)
                logger.info('avg_adv_loss: %f', adv_all_loss / adv_all_num)
    if all_num != 0:
        return float(all_loss * 1.0 / all_num)
    else:
        return 0
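# get_decay_co is called above (and again in adv_train below with extra logit
# arguments) but is not defined in this file. A minimal sketch of one plausible
# 3-argument form, assuming it returns a per-sentence coefficient that is 1
# while the model still assigns the adversarial target an average log-probability
# above a threshold (so the unlearning gradient keeps acting on it), and 0 once
# the target has been suppressed. The config key 'decay_thres' and the binary
# rule are assumptions, not the author's actual code.
def get_decay_co_sketch(batch_logpdf, tgt_len, adv_config):
    len_t = torch.FloatTensor(tgt_len).to(batch_logpdf.device)
    avg_logp = torch.sum(batch_logpdf, dim=1) / len_t   # per-sentence average log-prob
    thres = adv_config.get('decay_thres', -5.0)         # hypothetical threshold
    return (avg_logp > thres).float().detach()          # 1 = target still "hit", keep unlearning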
# Evaluate test perplexity with a dummy optimizer (lr=0, so even an accidental
# opt.step() would leave the weights unchanged), then dump per-minibatch test
# log-probabilities to disk for offline analysis.
opt = torch.optim.SGD(all_params, momentum=0.9, lr=0, weight_decay=1e-5)
loss_test = mle_train(batches_test, opt, do_train=False, adv_config=adv_config)
logger.info('test PPL: %f log-likelihood: %f', math.exp(-loss_test), loss_test)

res = {'loss_test': loss_test, 'mb_lis': []}
batches_test = DialogueBatches([TEST_FN], BATCH_SIZE, SRC_SEQ_LEN, TGT_SEQ_LEN,
                               vocab_inv, DATA_CONFIG, his_len=HIS)
for src_mb, tgt_mb, tgt_len, src_w, tgt_w in batches_test:
    batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, adv_config)
    res['mb_lis'].append((src_mb.detach().cpu(), tgt_mb.detach().cpu(), tgt_len,
                          src_w, tgt_w, batch_logpdf.detach().cpu()))

res_save_fn = ('figs/post_mal_advtrain/' + DATA_SET + '_IT' + str(TEST_ITER)
               + '_NEGLR' + str(NEG_LR) + '_POSLR' + str(POS_LR)
               + '_FAVOID' + str(FREQ_AVOID) + '_ILM' + str(ADV_I_LM_FLAG)
               + '_MFOCUS' + str(MIDDLE_FOCUS) + '_CARE' + ADV_CARE_MODE
               + '_testlogp.save')
print('saving to res_save_fn:', res_save_fn)
torch.save(res, res_save_fn)
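# The saved dict can be reloaded for offline inspection; a minimal usage sketch
# (per-sentence log-likelihoods recomputed from the stored per-token values):
res_loaded = torch.load(res_save_fn)
logger.info('reloaded loss_test: %f', res_loaded['loss_test'])
for src_mb, tgt_mb, tgt_len, src_w, tgt_w, batch_logpdf in res_loaded['mb_lis']:
    sen_logp = batch_logpdf.sum(dim=1)  # one log-likelihood per test sentence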
def adv_train(adv_batches, positive_batches, opts, do_train=True, do_log=True, do_adv=False, adv_config=None):
    # adv_train puts the adversarial (unlearning) updates first: for each batch
    # of attack targets it repeatedly pushes the target likelihood down until no
    # target is "hit" anymore (or 20 rounds pass), interleaving POSITIVE_RATIO
    # ordinary MLE updates after every adversarial step.
    # assert(do_train == True)
    opt_pos, opt_neg = opts
    all_loss, all_num, adv_all_loss, adv_all_num, nodecay_sennum, success0_co, success1_co = 0, 0, 0, 0, 0, 0, 0
    b_count, adv_b_count, adv_sen_count = 0, 0, 0
    attack_results = []
    loss_sen = []
    all_target_set = {}
    for target_mb in adv_batches:
        adv_src_mb, adv_tgt_mb, adv_tgt_len, adv_src_lis, adv_tgt_lis = get_adv_seq2seq_mb(
            target_mb, ADV_ATTACK, m_dict, adv_config)
        for l in adv_tgt_lis:
            all_target_set[' '.join(l[1:])] = True
        nondecay_lis_now = []
        co_now = 0
        while True:
            # print(tgt_lis[0], tgt_mb[0], tgt_len[0])
            for m in m_dict:
                # get_adv_seq2seq_mb could change the flags
                m_dict[m].zero_grad()
                m_dict[m].train()
            aux = {}
            adv_batch_logpdf = models.encoder_decoder_forward(
                adv_src_mb, adv_tgt_mb, adv_tgt_len, m_dict, adv_config, aux_return=aux)
            w_logit_rnn = aux['w_logit_rnn']
            decay_co = get_decay_co(adv_batch_logpdf, w_logit_rnn, adv_tgt_mb, adv_tgt_len, adv_config)
            attack_results.append((adv_src_mb, adv_tgt_mb, adv_tgt_len, adv_src_lis,
                                   adv_tgt_lis, decay_co, adv_batch_logpdf))
            # mask out positions that should not receive the unlearning gradient
            final_mask = torch.FloatTensor(adv_batch_logpdf.size()).cuda()
            final_mask[:] = 1
            if MIDDLE_FOCUS:
                # ignore the first and last target token
                for i in range(len(adv_tgt_len)):
                    final_mask[i][0] = 0
                    final_mask[i][adv_tgt_len[i] - 1] = 0
            if FREQ_AVOID:
                # ignore frequent words in the target
                for i in range(len(adv_tgt_len)):
                    for j in range(adv_tgt_len[i]):
                        if adv_tgt_lis[i][j + 1] in FREQ_AVOID_LIS:
                            final_mask[i][j] = 0
            """
            # debug: inspect masks
            for i in range(20):
                print(i, 'adv_tgt_lis:', adv_tgt_lis[i])
                print('mask:', final_mask[i])
                print('adv_tgt_len', adv_tgt_len[i])
                print('adv_batch_logpdf', adv_batch_logpdf[i])
                print('*final_mask', (adv_batch_logpdf * final_mask)[i])
            """
            sen_logpdf = torch.sum(adv_batch_logpdf * final_mask, dim=1) * decay_co
            (ADV_LAMBDA * torch.sum(sen_logpdf) / sum(adv_tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            if do_train:
                opt_neg.step()
            """
            # just for debug!!
            for i in range(decay_co.size(0)):
                if decay_co[i] == 1:
                    _, tgt_w, _ = target_mb
                    s_lis = adv_src_mb[i].cpu().numpy().tolist()
                    print('hit! target:', ' '.join(tgt_w[i]), 't_input:', ' '.join([vocab[ii] for ii in s_lis]))
                    print('hit hit hit in first mb')
            """
            nondecay_num_now = np.sum(decay_co.cpu().numpy() == 1)
            # logger.info('co_now: %d nondecay_num_now: %d', co_now, nondecay_num_now)
            nondecay_lis_now.append(nondecay_num_now)
            if co_now == 0:
                logger.info('displaying hit targets')
                for i in range(len(adv_tgt_len)):
                    if decay_co[i] == 1:
                        print('hit! target:', ' '.join(adv_tgt_lis[i]),
                              'trigger:', ' '.join(adv_src_lis[i]))
                adv_sen_count += adv_src_mb.size(0)
                adv_all_loss += torch.sum(adv_batch_logpdf).item()
                adv_all_num += sum(adv_tgt_len)
                nodecay_sennum += nondecay_num_now
                adv_b_count += 1
            if not do_train:
                break
            co_now += 1
            for kk in range(POSITIVE_RATIO):
                src_mb, tgt_mb, tgt_len, src_w, tgt_w = next(positive_batches)
                b_count = b_count + 1
                bz = src_mb.size(0)
                all_num = all_num + sum(tgt_len)
                batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, adv_config)
                w_loss_rnn = torch.sum(batch_logpdf)
                loss_sen.extend(torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())
                all_loss = all_loss + w_loss_rnn.data.item()
                if do_train:
                    for m in m_dict.values():
                        m.zero_grad()
                    (-w_loss_rnn / sum(tgt_len)).backward()
                    for m in m_dict.values():
                        torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
                    opt_pos.step()
            if nondecay_num_now == 0 or co_now > 20:
                if nondecay_num_now <= 0:
                    success0_co += 1
                if nondecay_num_now <= 1:
                    success1_co += 1
                break
        logger.info('current mb nondecay_lis_now(through each update): %s', str(nondecay_lis_now))
    logger.info('all_num: %d', all_num)
    if all_num > 0:
        logger.info('sen_avg_loss: %f', np.mean(loss_sen))
    if do_adv:
        logger.info('iter: %d adv_all_num: %d avg_nodecay_sennum: %f success0_rate: %f success1_rate: %f',
                    adv_config['iter_now'], adv_all_num,
                    float(nodecay_sennum * 1.0) / adv_b_count,
                    float(success0_co * 100.0 / adv_b_count),
                    float(success1_co * 100.0 / adv_b_count))
        logger.info('debug for avg_nodecay_sennum: %d / %d', nodecay_sennum, adv_b_count)
        logger.info('avg_adv_loss: %f', adv_all_loss / adv_all_num)
    res = {
        'positive_avg_loss': 0 if all_num == 0 else all_loss * 1.0 / all_num,
        'nodecay_sennum': nodecay_sennum,
        'adv_sen_count': adv_sen_count,
        'attack_results': attack_results,
        'all_target_set': all_target_set,
    }
    return res
def neg_d_train(aux_batches, train_batches, opts, do_train=True, do_log=True, do_adv=False, adv_config=None):
    # Negative training against generic responses: sample from the model, mark
    # samples that appear too frequently in a sliding history (ratio > R_THRES),
    # and push their likelihood down, interleaved with discriminator updates
    # (D_RATIO per batch) and ordinary MLE updates (POSITIVE_RATIO per batch).
    opt_pos, opt_neg, opt_d = opts
    loss_sen = []
    inf_pos_batches, inf_d_batches = aux_batches
    d_b_co, pos_b_co, neg_b_co = 0, 0, 0
    sample_mb_longhis = []
    stat_dic = MyStatDic()
    for train_mb in train_batches:
        neg_b_co += 1
        src_mb, tgt_mb, tgt_len, src_w, tgt_w = train_mb
        bz = src_mb.size(0)
        res_sample = models.get_samples([train_mb], TRAIN_SAMPLE_TYPE, m_dict, adv_config)
        sample_mb, samples_w = res_sample['raw_sample_lis'][0], res_sample['sample_lis']
        neg_lambda = [0 for kk in range(len(samples_w))]
        # keep a sliding history of the last 200 sample minibatches
        sample_mb_longhis.append(samples_w)
        if len(sample_mb_longhis) > 200:
            sample_mb_longhis = sample_mb_longhis[1:]
        d_r = make_ratio_dict(sample_mb_longhis)
        for i in range(BATCH_SIZE):
            if d_r[' '.join(samples_w[i])] > R_THRES:
                # found a frequent (generic) response; mark it for a negative update
                neg_lambda[i] = 1
        neg_lambda = torch.FloatTensor(neg_lambda).cuda()
        sample_tgt_mb, sample_tgt_len = form_tgtmb(sample_mb)
        for m in m_dict.values():
            m.zero_grad()
            m.train()
        aux = {}
        adv_batch_logpdf = models.encoder_decoder_forward(
            src_mb, sample_tgt_mb, sample_tgt_len, m_dict, adv_config, aux_return=aux)
        w_logit_rnn = aux['w_logit_rnn']
        final_mask = torch.FloatTensor(adv_batch_logpdf.size()).cuda()
        final_mask[:] = 1
        if FREQ_AVOID:
            # rescale the gradient on the final (end-of-sentence) position
            for i in range(len(sample_tgt_len)):
                final_mask[i][sample_tgt_len[i] - 1] = FREQ_AVOID_SCAL
        neg_logpdf = torch.sum(adv_batch_logpdf * final_mask, dim=1) * neg_lambda
        if do_train:
            (torch.sum(neg_logpdf) / sum(sample_tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt_neg.step()
        stat_dic.append_dict({'neg_lambda': neg_lambda.mean().item()})
        stat_dic.append_dict({'neg_loss': torch.sum(adv_batch_logpdf).item() / sum(sample_tgt_len)})
        if neg_b_co % 2000 == 0 and DEBUG_INFO:
            print('peek mb sample!')
            for i in range(sample_tgt_mb.size(0) // 5):
                logger.info('neg_lambda: %f %s', neg_lambda[i].item(),
                            ' '.join([vocab[sample_tgt_mb[i][j]]
                                      for j in range(sample_tgt_len[i] + 1)]))
        for kk in range(D_RATIO):
            res_d = d_train_mb(next(inf_d_batches), opt_d, do_train=do_train)
            d_b_co += 1
            stat_dic.append_dict(res_d, keys=['d_loss', 'd_acc_ratio'])
        for kk in range(POSITIVE_RATIO):
            src_mb, tgt_mb, tgt_len, src_w, tgt_w = next(inf_pos_batches)
            pos_b_co += 1
            bz = src_mb.size(0)
            for m in m_dict.values():
                m.train()
                m.zero_grad()
            batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, adv_config)
            w_loss_rnn = torch.sum(batch_logpdf)
            loss_sen.extend(torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())
            stat_dic.append_dict({'pos_loss': (-w_loss_rnn.item() / sum(tgt_len))})
            if do_train:
                (-w_loss_rnn / sum(tgt_len)).backward()
                for m in m_dict.values():
                    torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
                opt_pos.step()
        if neg_b_co % 200 == 0:
            stat_dic.log_mean(last_num=200,
                              log_pre='it{} neg_b_co: {}'.format(adv_config['iter_now'], neg_b_co))
    return stat_dic
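# make_ratio_dict is used above but not defined in this file. From its usage it
# maps a detokenized response string to the fraction of recent samples equal to
# it, so that responses above R_THRES count as over-frequent. A minimal sketch
# under that assumption (the exact normalization is a guess):
def make_ratio_dict_sketch(sample_mb_longhis):
    from collections import defaultdict
    counts = defaultdict(int)
    total = 0
    for samples_w in sample_mb_longhis:   # each element: one minibatch of tokenized samples
        for w_lis in samples_w:
            counts[' '.join(w_lis)] += 1
            total += 1
    # fraction of all recent samples that equal each response string
    return {s: c * 1.0 / total for s, c in counts.items()}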