def encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, config, aux_return=None):
    # Teacher-forced forward pass of the seq2seq model; returns the length-masked
    # per-token log-likelihood of the target batch, shape (batch, tgt_max_len).
    MT, HIDDEN_SIZE, vocab, LAYER_NUM = config['MT'], config['HIDDEN_SIZE'], config['vocab'], config['LAYER_NUM']
    m_encode_w_rnn, m_decode_w_rnn, m_embed = m_dict['m_encode_w_rnn'], m_dict['m_decode_w_rnn'], m_dict['m_embed']
    bz = src_mb.size(0)
    src_inputv = Variable(src_mb).cuda()
    tgt_inputv = Variable(tgt_mb[:, :-1]).cuda()
    tgt_targetv = Variable(tgt_mb[:, 1:]).cuda()
    tgt_mask = Variable(torch.FloatTensor(
        [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
        requires_grad=False)

    output, _ = m_encode_w_rnn(
        m_embed(idx2onehot(src_inputv, len(vocab))).permute(1, 0, 2),
        init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))  # for parallel!
    e_output = output.permute(1, 0, 2)  # for parallel!

    if MT == 'latent':
        # use the last encoder state, repeated across the target length
        latent = e_output[:, -1, :].squeeze(1)
        latent = latent.unsqueeze(1).repeat(1, tgt_inputv.size(1), 1)
        w_logit_rnn = m_decode_w_rnn(latent, tgt_inputv, tgt_len)  # changed from decode to forward for data parallel
    if MT == 'attention':
        w_logit_rnn, attn_weights, _ = m_decode_w_rnn(e_output, tgt_inputv)

    if aux_return is not None:
        aux_return['w_logit_rnn'] = w_logit_rnn

    flat_output = w_logit_rnn.view(-1, len(vocab))
    flat_target = tgt_targetv.contiguous().view(-1)
    flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
    batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
    return batch_logpdf
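# --- Hedged sketch (not part of the original file) --------------------------------
# encoder_decoder_forward relies on two helpers defined elsewhere in this repo,
# idx2onehot and init_lstm_hidden, whose implementations are not shown here. The
# versions below are minimal sketches of the assumed behavior: idx2onehot presumably
# turns an index tensor (batch, length) into a one-hot float tensor
# (batch, length, vocab_size) so that m_embed (apparently a linear projection over the
# vocabulary dimension, since it also consumes relaxed one-hot inputs) can be applied,
# and init_lstm_hidden presumably returns zero-initialized (h0, c0) LSTM states.
def _idx2onehot_sketch(idx, vocab_size):
    # idx: LongTensor (batch, length) -> one-hot FloatTensor (batch, length, vocab_size)
    onehot = torch.zeros(idx.size(0), idx.size(1), vocab_size, device=idx.device)
    return onehot.scatter_(2, idx.unsqueeze(2), 1.0)

def _init_lstm_hidden_sketch(bz, hidden_size, layer_num=1):
    # zero-initialized hidden and cell states for an LSTM with layer_num layers
    h0 = torch.zeros(layer_num, bz, hidden_size).cuda()
    c0 = torch.zeros(layer_num, bz, hidden_size).cuda()
    return (h0, c0)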
def adversarial_latent_optimize(lis_tgt_w):
    # Optimize a single latent vector per sample (repeated across the target length)
    # so that the decoder assigns high likelihood to the given target sentences.
    bz = len(lis_tgt_w)
    tgt_len = [len(l) - 1 for l in lis_tgt_w]
    max_len = max(tgt_len)
    #latent_v = Variable(torch.randn(bz, max_len, HIDDEN_SIZE).cuda(), requires_grad=True)
    latent_v_ori = Variable(torch.randn(bz, 1, HIDDEN_SIZE).cuda(), requires_grad=True)
    latent_v = latent_v_ori.repeat(1, max_len, 1)
    latent_opt = torch.optim.SGD([latent_v_ori], momentum=0.9, lr=1,
                                 weight_decay=1e-5)  # lr: the larger the better, it seems

    lis_tgt_w = [l + ['<pad>'] * (max_len + 1 - len(l)) for l in lis_tgt_w]
    lis_tgt_idx = [[vocab_inv[w] for w in l] for l in lis_tgt_w]
    tgt_mb = Variable(torch.LongTensor(lis_tgt_idx).cuda())
    tgt_inputv = tgt_mb[:, :-1]
    tgt_targetv = tgt_mb[:, 1:]
    tgt_mask_ori = Variable(torch.FloatTensor(
        [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
        requires_grad=False)

    for it in range(4000 + 1):
        latent_v = latent_v_ori.repeat(1, max_len, 1)
        w_logit_rnn = m_decode_w_rnn.forward(F.tanh(latent_v), tgt_inputv)
        flat_output = w_logit_rnn.view(-1, len(vocab))
        flat_target = tgt_targetv.contiguous().view(-1)
        flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)

        # Down-weight tokens the decoder already predicts correctly, so the
        # optimization focuses on the tokens that are still wrong.
        tgt_mask = Variable(torch.FloatTensor(
            [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
            requires_grad=False)
        w_logit_pred = torch.max(w_logit_rnn, dim=2)[1]
        for i in range(bz):
            for j in range(tgt_len[i]):
                if w_logit_pred[i][j] == tgt_targetv[i][j]:
                    tgt_mask[i][j] = WEIGHT_LOSS_DECAY

        batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
        w_loss_rnn = torch.sum(batch_logpdf)
        w_loss_rnn_ori = torch.sum(flat_logpdf.view(bz, -1) * tgt_mask_ori)

        latent_opt.zero_grad()
        (-w_loss_rnn / sum(tgt_len)).backward()
        latent_opt.step()

        avg_loss = (-w_loss_rnn_ori / sum(tgt_len)).item()
        if it % 500 == 0:
            logger.info('it %d avg_loss: %f', it, avg_loss)

    return latent_v, avg_loss
def lm_model_forward(model, inputv, targetv, b_len, vocab, do_train=False):
    if do_train:
        model.train()
    else:
        model.eval()
    bz = inputv.size(0)
    maskv = Variable(mask_gen(b_len, ty='Float')).cuda()  # size (batch, length)
    output, _ = model(idx2onehot(inputv, len(vocab)),
                      model.initHidden(batch_size=inputv.size(0)))
    w_logpdf = lib_pdf.logsoftmax_idxselect(
        output.view(-1, len(vocab)), targetv.contiguous().view(-1)).view(bz, -1)
    w_logpdf = w_logpdf * maskv
    return w_logpdf
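# --- Hedged sketch (not part of the original file) --------------------------------
# lm_model_forward assumes two more helpers from this repo: mask_gen, which presumably
# builds a (batch, max_len) 0/1 mask from a list of lengths, and
# lib_pdf.logsoftmax_idxselect, which presumably picks out the log-softmax value at
# each target index, i.e. the per-token log-likelihood used throughout this file.
# Minimal sketches of the assumed behavior:
def _mask_gen_sketch(b_len, ty='Float'):
    # b_len: list of lengths -> (batch, max_len) mask with 1s up to each length
    max_len = max(b_len)
    mask = torch.FloatTensor([[1] * l + [0] * (max_len - l) for l in b_len])
    return mask.long() if ty == 'Long' else mask

def _logsoftmax_idxselect_sketch(logits, target_idx):
    # logits: (N, vocab), target_idx: (N,) -> (N,) log p(target | context)
    logp = F.log_softmax(logits, dim=1)
    return logp.gather(1, target_idx.unsqueeze(1)).squeeze(1)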
def adv_model_forward(input_onehot_v, input_idx_lis, target_mb, m_dict, adv_config,
                      decay_loss=True, do_train=True):
    # Forward pass used by the adversarial (gibbs-enum) attack. Returns the decoder
    # logits and argmax predictions, per-sample summed / length-normalized target
    # log-likelihoods, the combined objective, and the optional input-LM score.
    # With decay_loss on, tokens/samples that already satisfy the hit criterion
    # (ADV_CARE_MODE) get their loss weight scaled down to 0.01.
    bz = input_onehot_v.size(0)
    # inject config values (MT, HIDDEN_SIZE, ADV_CARE_MODE, ...) and modules into globals
    globals().update(adv_config)
    globals().update(m_dict)
    m_list = [m_embed, m_encode_w_rnn, ADV_I_LM, m_decode_w_rnn,
              m_embed_dp, m_encode_w_rnn_dp, m_decode_w_rnn_dp]
    for m in m_list:
        if m is None:
            continue
        if do_train:
            m.train()
        else:
            m.eval()

    tgt_idx, tgt_w, tgt_len = target_mb
    tgt_inputv = tgt_idx[:, :-1].cuda()
    tgt_targetv = tgt_idx[:, 1:].cuda()

    output, _ = m_encode_w_rnn_dp(
        m_embed_dp(input_onehot_v).permute(1, 0, 2),
        init_lstm_hidden(bz, HIDDEN_SIZE))
    output = output.permute(1, 0, 2)
    latent = output[:, -1, :].unsqueeze(1).repeat(1, tgt_targetv.size(1), 1)
    if MT == 'latent':
        w_logit_rnn = m_decode_w_rnn_dp(latent, tgt_inputv, tgt_len)
    if MT == 'attention':
        w_logit_rnn, attn_weights, _ = m_decode_w_rnn_dp(output, tgt_inputv)

    flat_output = w_logit_rnn.view(-1, len(vocab))
    flat_target = tgt_targetv.contiguous().view(-1)
    flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
    batch_logpdf = flat_logpdf.view(bz, -1)

    tgt_mask = Variable(torch.FloatTensor(
        [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
        requires_grad=False)
    w_logit_pred = torch.max(w_logit_rnn, dim=2)[1]

    # Per-token decay and per-sample minimum log-likelihood (o_mins).
    o_mins = []
    for i in range(bz):
        minn = 1000
        for j in range(tgt_len[i]):
            if ADV_CARE_MODE == 'max' and decay_loss == True and w_logit_pred[i][j] == tgt_targetv[i][j]:
                tgt_mask[i][j] = min(tgt_mask[i][j], 0.01)  # why not just zero?
            if ADV_CARE_MODE == 'sample_min' and decay_loss == True and batch_logpdf[i][j].item() >= NORMAL_WORD_AVG_LOSS:
                tgt_mask[i][j] = min(tgt_mask[i][j], 0.01)
            if batch_logpdf[i][j].item() < minn:
                minn = batch_logpdf[i][j].item()
        o_mins.append(minn)

    weight_batch_logpdf = batch_logpdf * tgt_mask
    sen_loss_rnn = torch.sum(weight_batch_logpdf, dim=1)
    w_loss_rnn = sen_loss_rnn / torch.FloatTensor(tgt_len).cuda()

    if ADV_I_LM_FLAG == True:
        i_w_loss_lm = ADV_I_LM.calMeanLogp(input_idx_lis, input_onehot_v,
                                           mode='input_eou', train_flag=do_train)
    else:
        i_w_loss_lm = None

    # Per-sample scaling: stop pushing samples that already satisfy the criterion.
    o_scal = Variable(torch.ones(bz).cuda(), requires_grad=False)
    i_scal = Variable(torch.ones(bz).cuda(), requires_grad=False)
    for k in range(bz):
        if ADV_I_LM_FLAG == True and decay_loss == True:
            if i_w_loss_lm[k].item() > GE_I_WORD_AVG_LOSS:
                i_scal[k] = 0.01
        if ADV_CARE_MODE == 'sample_avg' and w_loss_rnn[k].item() > NORMAL_WORD_AVG_LOSS:
            o_scal[k] = 0.01
        if ADV_CARE_MODE == 'sample_min' and o_mins[k] > NORMAL_WORD_AVG_LOSS:
            o_scal[k] = 0.01

    if ADV_I_LM_FLAG == False:
        loss_combine = w_loss_rnn
    else:
        loss_combine = w_loss_rnn * o_scal + i_w_loss_lm * i_scal * GIBBSENUM_I_LM_LAMBDA

    return (w_logit_rnn, w_logit_pred, sen_loss_rnn, w_loss_rnn, loss_combine,
            i_w_loss_lm, o_mins, batch_logpdf)
def adversarial_input_optimize(lis_tgt_w, mode='embed'):
    # Optimize a continuous relaxation of the *input* (embedding vectors or relaxed
    # one-hot vectors, depending on `mode`) so that the seq2seq model assigns high
    # likelihood to the given target sentences.
    assert (mode == 'embed' or mode == 'softmax_idx' or mode == 'linear_idx'
            or mode == 'softplus_idx' or mode == 'sigmoid_idx')
    bz = len(lis_tgt_w)
    tgt_len = [len(l) - 1 for l in lis_tgt_w]
    max_len = max(tgt_len)

    v_list = []
    if mode == 'embed':
        embed_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY, EMBED_SIZE).cuda(),
                           requires_grad=True)
        v_list.append(embed_v)
    elif mode == 'softmax_idx':
        ll = len(vocab)
        if SOFTMAX_IDX_EL >= 0:
            ll = SOFTMAX_IDX_EL
        onehot_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY, ll).cuda(), requires_grad=True)
        v_list.append(onehot_v)
    elif mode == 'sigmoid_idx':
        onehot_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY, len(vocab)).cuda() - 3,
                            requires_grad=True)
        v_list.append(onehot_v)
    elif mode == 'linear_idx' or mode == 'softplus_idx':
        onehot_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY, len(vocab)).cuda(),
                            requires_grad=True)
        v_list.append(onehot_v)

    lis_tgt_w = [l + ['<pad>'] * (max_len + 1 - len(l)) for l in lis_tgt_w]
    lis_tgt_idx = [[vocab_inv[w] for w in l] for l in lis_tgt_w]
    tgt_mb = Variable(torch.LongTensor(lis_tgt_idx).cuda())
    tgt_inputv = tgt_mb[:, :-1]
    tgt_targetv = tgt_mb[:, 1:]
    tgt_mask_ori = Variable(torch.FloatTensor(
        [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
        requires_grad=False)

    # With a larger batch the per-sample learning rate effectively shrinks, so scale
    # it by bz; for mode 'embed', lr=1 works best.
    lr = 1 * bz
    for epoch in range(8):
        opt_v = torch.optim.SGD(v_list, momentum=0.9, lr=lr,
                                weight_decay=1e-5)  # lr: the larger the better, it seems
        #opt_v = torch.optim.Adam(v_list, lr=1e-4)  # debug adam!
        #lr = lr * 0.6  # experiments show a constant lr is better
        for it in range(500 * epoch, 500 * (epoch + 1)):
            if mode == 'embed':
                output, _ = m_encode_w_rnn(
                    embed_v, init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'softmax_idx':
                #output, _ = m_encode_w_rnn(m_embed(F.softmax(onehot_v, dim=2)), init_lstm_hidden(bz, HIDDEN_SIZE))
                output, _ = m_encode_w_rnn(
                    m_embed(softmax_idx_el_glue(onehot_v)),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'linear_idx':
                output, _ = m_encode_w_rnn(
                    m_embed(onehot_v),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'softplus_idx':
                output, _ = m_encode_w_rnn(
                    m_embed(F.softplus(onehot_v)),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'sigmoid_idx':
                output, _ = m_encode_w_rnn(
                    m_embed(F.sigmoid(onehot_v)).permute(1, 0, 2),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            output = output.permute(1, 0, 2)

            w_logit_rnn, attn_weights = decoder_forward(output, tgt_inputv, tgt_len)
            flat_output = w_logit_rnn.view(-1, len(vocab))
            flat_target = tgt_targetv.contiguous().view(-1)
            flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)

            # Down-weight tokens that are already predicted correctly.
            tgt_mask = Variable(torch.FloatTensor(
                [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
                requires_grad=False)
            w_logit_pred = torch.max(w_logit_rnn, dim=2)[1]
            for i in range(bz):
                for j in range(tgt_len[i]):
                    if w_logit_pred[i][j] == tgt_targetv[i][j]:
                        tgt_mask[i][j] = WEIGHT_LOSS_DECAY

            batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
            w_loss_rnn = torch.sum(batch_logpdf)
            w_loss_rnn_ori = torch.sum(flat_logpdf.view(bz, -1) * tgt_mask_ori)

            opt_v.zero_grad()
            (-w_loss_rnn / sum(tgt_len)).backward(retain_graph=True)
            if (mode == 'sigmoid_idx' or mode == 'softmax_idx') and LASSO_LAMBDA > 0:
                # Sparsity regularizer: push each relaxed one-hot row toward a single peak.
                if mode == 'sigmoid_idx':
                    sv = F.sigmoid(onehot_v)
                if mode == 'softmax_idx':
                    sv = F.softmax(onehot_v, dim=2)
                lasso_loss = LASSO_LAMBDA * torch.sum(sv, dim=2).view(-1).mean()
                lasso_loss += (-LASSO_LAMBDA * 2) * (torch.max(sv, dim=2)[0]).view(-1).mean()
                lasso_loss.backward(retain_graph=True)
            #print(torch.max(torch.abs(onehot_v.grad)))  # the softmaxed gradient is very small
            opt_v.step()

            avg_loss = (-w_loss_rnn_ori / sum(tgt_len)).item()
            if it % 200 == 0:
                logger.info('it %d avg_loss: %f', it, avg_loss)

    if mode == 'embed':
        return embed_v, avg_loss
    else:
        return onehot_v, avg_loss
def mle_train(batches, opt, do_train=True, do_log=True):
    # One pass of MLE (teacher-forcing) training or evaluation over the given batches;
    # returns the average per-token log-likelihood (sum of masked log-probs / token count).
    for m in m_dict:
        if do_train:
            m_dict[m].train()
        else:
            m_dict[m].eval()
    all_loss = 0
    all_num = 0
    b_count = 0
    loss_sen = []
    for src_mb, tgt_mb, tgt_len, src_w, tgt_w in batches:
        b_count = b_count + 1
        bz = src_mb.size(0)
        all_num = all_num + sum(tgt_len)
        src_inputv = Variable(src_mb).cuda()
        tgt_inputv = Variable(tgt_mb[:, :-1]).cuda()
        tgt_targetv = Variable(tgt_mb[:, 1:]).cuda()
        tgt_mask = Variable(torch.FloatTensor(
            [([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(),
            requires_grad=False)

        output, _ = m_encode_w_rnn_dp(
            m_embed_dp(idx2onehot(src_inputv, len(vocab))).permute(1, 0, 2),
            init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))  # for parallel!
        output = output.permute(1, 0, 2)  # for parallel!
        w_logit_rnn, attn_weights = decoder_forward(output, tgt_inputv, tgt_len)

        flat_output = w_logit_rnn.view(-1, len(vocab))
        flat_target = tgt_targetv.contiguous().view(-1)
        flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
        batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
        w_loss_rnn = torch.sum(batch_logpdf)
        loss_sen.extend(torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())
        all_loss = all_loss + w_loss_rnn.item()

        if do_train:
            for m in m_dict.values():
                m.zero_grad()
            (-w_loss_rnn / sum(tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt.step()

        if do_log and b_count % LOG_INTERVAL == 0:
            logger.info('avg loss at b: %d , %f', b_count, all_loss * 1.0 / all_num)

    logger.info('all_num: %d', all_num)
    logger.info('sen_avg_loss: %f', np.mean(loss_sen))
    return float(all_loss * 1.0 / all_num)
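# --- Hedged usage sketch (not part of the original file) --------------------------
# A possible training driver around mle_train. The names make_batches, train_data,
# valid_data and num_epoch are placeholders for whatever the surrounding script
# provides; mle_train itself only needs an iterable of
# (src_mb, tgt_mb, tgt_len, src_w, tgt_w) batches and an optimizer over m_dict.
def _mle_train_driver_sketch(make_batches, train_data, valid_data, num_epoch=10):
    opt = torch.optim.Adam([p for m in m_dict.values() for p in m.parameters()], lr=1e-3)
    for ep in range(num_epoch):
        train_loss = mle_train(make_batches(train_data), opt, do_train=True)
        valid_loss = mle_train(make_batches(valid_data), opt, do_train=False, do_log=False)
        logger.info('epoch %d train_loss: %f valid_loss: %f', ep, train_loss, valid_loss)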