def encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, config, aux_return=None):
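    """Run one encoder-decoder forward pass and return masked per-token log-probabilities.

    src_mb / tgt_mb are LongTensor index batches and tgt_len gives the target lengths
    used for masking; m_dict supplies the embedding, encoder and decoder modules, and
    config carries MT ('latent' or 'attention'), HIDDEN_SIZE, vocab and LAYER_NUM.
    Returns a (batch, target_length) tensor of log-pdfs, zeroed at padding positions.
    """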
    MT, HIDDEN_SIZE, vocab, LAYER_NUM = config['MT'], config['HIDDEN_SIZE'], config['vocab'], config['LAYER_NUM']
    m_encode_w_rnn, m_decode_w_rnn, m_embed = m_dict['m_encode_w_rnn'], m_dict['m_decode_w_rnn'], m_dict['m_embed']
    bz = src_mb.size(0)
    src_inputv = Variable(src_mb).cuda() 
    tgt_inputv = Variable(tgt_mb[:, :-1]).cuda() 
    tgt_targetv = Variable(tgt_mb[:, 1:]).cuda()
    tgt_mask = Variable(torch.FloatTensor([([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len]).cuda(), requires_grad = False)

    output, _ = m_encode_w_rnn(m_embed(idx2onehot(src_inputv, len(vocab))).permute(1, 0, 2), init_lstm_hidden(bz, HIDDEN_SIZE, layer_num = LAYER_NUM)) #for parallel!
    e_output = output.permute(1, 0, 2) #for parallel!
    
    if MT == 'latent':
        latent = e_output[:, -1, :].squeeze(1)
        latent = latent.unsqueeze(1).repeat(1, tgt_inputv.size(1), 1)
        w_logit_rnn = m_decode_w_rnn(latent, tgt_inputv, tgt_len) #change from decode to forward for data parallel
    elif MT == 'attention':
        w_logit_rnn, attn_weights, _ = m_decode_w_rnn(e_output, tgt_inputv)
    
    if aux_return is not None:
        aux_return['w_logit_rnn'] = w_logit_rnn

    flat_output = w_logit_rnn.view(-1, len(vocab))
    flat_target = tgt_targetv.contiguous().view(-1)
    flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
    batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
     
    return batch_logpdf
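# A minimal, hypothetical call sketch for the function above; the config values and
# module names are illustrative assumptions, not taken from this repository:
#
#   config = {'MT': 'attention', 'HIDDEN_SIZE': 512, 'vocab': vocab, 'LAYER_NUM': 1}
#   m_dict = {'m_embed': m_embed, 'm_encode_w_rnn': m_encode_w_rnn,
#             'm_decode_w_rnn': m_decode_w_rnn}
#   batch_logpdf = encoder_decoder_forward(src_mb, tgt_mb, tgt_len, m_dict, config)
#   loss = -torch.sum(batch_logpdf) / sum(tgt_len)  # average per-token NLL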
Example #2
def adversarial_latent_optimize(lis_tgt_w):
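    """Optimize a single latent vector per sentence (repeated across time steps) with
    SGD so that the frozen decoder m_decode_w_rnn assigns high likelihood to the target
    sentences in lis_tgt_w. Positions the decoder already predicts correctly are
    down-weighted by WEIGHT_LOSS_DECAY so the optimization focuses on the remaining
    tokens. Returns the optimized latent and the final average loss.
    """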
    bz = len(lis_tgt_w)
    tgt_len = [len(l) - 1 for l in lis_tgt_w]
    max_len = max(tgt_len)
    #latent_v = Variable(torch.randn(bz, max_len, HIDDEN_SIZE).cuda(), requires_grad = True)
    latent_v_ori = Variable(torch.randn(bz, 1, HIDDEN_SIZE).cuda(),
                            requires_grad=True)
    latent_v = latent_v_ori.repeat(1, max_len, 1)
    latent_opt = torch.optim.SGD(
        [latent_v_ori], momentum=0.9, lr=1,
        weight_decay=1e-5)  #the larger the better, it seems
    lis_tgt_w = [l + ['<pad>'] * (max_len + 1 - len(l)) for l in lis_tgt_w]
    lis_tgt_idx = [[vocab_inv[w] for w in l] for l in lis_tgt_w]
    tgt_mb = Variable(torch.LongTensor(lis_tgt_idx).cuda())
    tgt_inputv = tgt_mb[:, :-1]
    tgt_targetv = tgt_mb[:, 1:]
    tgt_mask_ori = Variable(torch.FloatTensor([
        ([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len
    ]).cuda(),
                            requires_grad=False)

    for it in range(4000 + 1):
        #print latent_v.size(), tgt_inputv.size(), tgt_len
        latent_v = latent_v_ori.repeat(1, max_len, 1)
        w_logit_rnn = m_decode_w_rnn.forward(torch.tanh(latent_v), tgt_inputv)
        flat_output = w_logit_rnn.view(-1, len(vocab))
        flat_target = tgt_targetv.contiguous().view(-1)
        flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
        tgt_mask = Variable(torch.FloatTensor([
            ([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len
        ]).cuda(),
                            requires_grad=False)
        w_logit_pred = torch.max(w_logit_rnn, dim=2)[1]
        for i in range(bz):
            for j in range(tgt_len[i]):
                if w_logit_pred[i][j] == tgt_targetv[i][j]:
                    tgt_mask[i][j] = WEIGHT_LOSS_DECAY

        batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
        w_loss_rnn = torch.sum(batch_logpdf)

        w_loss_rnn_ori = torch.sum(flat_logpdf.view(bz, -1) * tgt_mask_ori)

        latent_opt.zero_grad()
        (-w_loss_rnn / sum(tgt_len)).backward()
        latent_opt.step()

        avg_loss = (-w_loss_rnn_ori / sum(tgt_len)).item()
        if it % 500 == 0:
            logger.info('it %d avg_loss: %f', it, avg_loss)

    return latent_v, avg_loss
Example #3
def lm_model_forward(model, inputv, targetv, b_len, vocab, do_train=False):
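    """Language-model forward pass: returns the masked per-token log-probabilities
    of targetv given inputv, as a (batch, length) tensor."""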
    if do_train:
        model.train()
    else:
        model.eval()

    bz = inputv.size(0)
    maskv = Variable(mask_gen(b_len, ty='Float')).cuda()
    #size(batch, length)
    output, _ = model(idx2onehot(inputv, len(vocab)),
                      model.initHidden(batch_size=inputv.size(0)))

    w_logpdf = lib_pdf.logsoftmax_idxselect(
        output.view(-1, len(vocab)),
        targetv.contiguous().view(-1)).view(bz, -1)
    w_logpdf = w_logpdf * maskv

    return w_logpdf
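# Hypothetical usage sketch (variable names are assumptions; the normalization mirrors
# how the seq2seq losses are averaged elsewhere in this file):
#
#   w_logpdf = lm_model_forward(lm_model, inputv, targetv, b_len, vocab)
#   avg_nll = -torch.sum(w_logpdf) / sum(b_len)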
Example #4
def adv_model_forward(input_onehot_v,
                      input_idx_lis,
                      target_mb,
                      m_dict,
                      adv_config,
                      decay_loss=True,
                      do_train=True):
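    """Seq2seq forward pass used during the adversarial (Gibbs-enum) search.

    Encodes the (possibly soft) one-hot input, decodes the target, and builds a
    per-sample loss in which already-hit tokens or samples are down-weighted
    according to ADV_CARE_MODE; optionally adds an input language-model term
    (ADV_I_LM) weighted by GIBBSENUM_I_LM_LAMBDA. Note that adv_config and m_dict
    are injected into globals() below.
    """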
    bz = input_onehot_v.size(0)
    globals().update(adv_config)
    globals().update(m_dict)
    m_list = [
        m_embed, m_encode_w_rnn, ADV_I_LM, m_decode_w_rnn, m_embed_dp,
        m_encode_w_rnn_dp, m_decode_w_rnn_dp
    ]
    for m in m_list:
        if m is None:
            continue
        if do_train:
            m.train()
        else:
            m.eval()

    tgt_idx, tgt_w, tgt_len = target_mb
    tgt_inputv = tgt_idx[:, :-1].cuda()
    tgt_targetv = tgt_idx[:, 1:].cuda()

    output, _ = m_encode_w_rnn_dp(
        m_embed_dp(input_onehot_v).permute(1, 0, 2),
        init_lstm_hidden(bz, HIDDEN_SIZE))
    output = output.permute(1, 0, 2)
    latent = output[:, -1, :].unsqueeze(1).repeat(1, tgt_targetv.size(1), 1)

    if MT == 'latent':
        w_logit_rnn = m_decode_w_rnn_dp(latent, tgt_inputv, tgt_len)
    elif MT == 'attention':
        w_logit_rnn, attn_weights, _ = m_decode_w_rnn_dp(output, tgt_inputv)

    flat_output = w_logit_rnn.view(-1, len(vocab))
    flat_target = tgt_targetv.contiguous().view(-1)
    flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
    batch_logpdf = flat_logpdf.view(bz, -1)
    #print 'tgt_targetv[8,9,10]', tgt_targetv[8], tgt_targetv[9], tgt_targetv[10]
    #print tgt_w[8], tgt_w[9]

    tgt_mask = Variable(torch.FloatTensor([
        ([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len
    ]).cuda(),
                        requires_grad=False)
    w_logit_pred = torch.max(w_logit_rnn, dim=2)[1]
    #if decay_loss == True:
    #print 'decay_weight:', GIBBSENUM_DECAYLOSS_WEIGHT
    o_mins = []
    for i in range(bz):
        minn = 1000
        for j in range(tgt_len[i]):
            if ADV_CARE_MODE == 'max' and decay_loss == True and w_logit_pred[
                    i][j] == tgt_targetv[i][j]:
                tgt_mask[i][j] = min(tgt_mask[i][j],
                                     0.01)  #debug! #why not just zero?
            if ADV_CARE_MODE == 'sample_min' and decay_loss == True and batch_logpdf[
                    i][j].item() >= NORMAL_WORD_AVG_LOSS:
                tgt_mask[i][j] = min(tgt_mask[i][j], 0.01)
            if batch_logpdf[i][j].item() < minn:
                minn = batch_logpdf[i][j].item()
            #if tgt_targetv[i][j] != 0:
            #logger.info('setting to 0.001: |%s|', vocab[tgt_targetv[i][j]])
        o_mins.append(minn)
    weight_batch_logpdf = batch_logpdf * tgt_mask
    sen_loss_rnn = torch.sum(weight_batch_logpdf, dim=1)
    w_loss_rnn = sen_loss_rnn / torch.FloatTensor(tgt_len).cuda()

    if ADV_I_LM_FLAG == True:
        i_w_loss_lm = ADV_I_LM.calMeanLogp(input_idx_lis,
                                           input_onehot_v,
                                           mode='input_eou',
                                           train_flag=do_train)
    else:
        i_w_loss_lm = None
    o_scal = Variable(torch.ones(bz).cuda(), requires_grad=False)
    i_scal = Variable(torch.ones(bz).cuda(), requires_grad=False)
    #print i_w_loss_lm.size()
    for k in range(bz):
        if ADV_I_LM_FLAG == True and decay_loss == True:
            if i_w_loss_lm[k].item() > GE_I_WORD_AVG_LOSS: i_scal[k] = 0.01
        if ADV_CARE_MODE == 'sample_avg' and w_loss_rnn[k].item(
        ) > NORMAL_WORD_AVG_LOSS:
            o_scal[k] = 0.01
        if ADV_CARE_MODE == 'sample_min' and o_mins[k] > NORMAL_WORD_AVG_LOSS:
            o_scal[k] = 0.01

    if ADV_I_LM_FLAG == False:
        loss_combine = w_loss_rnn
    else:
        loss_combine = w_loss_rnn * o_scal + i_w_loss_lm * i_scal * GIBBSENUM_I_LM_LAMBDA

    return w_logit_rnn, w_logit_pred, sen_loss_rnn, w_loss_rnn, loss_combine, i_w_loss_lm, o_mins, batch_logpdf
Example #5
def adversarial_input_optimize(lis_tgt_w, mode='embed'):
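    """Optimize a continuous relaxation of the encoder input (a raw embedding or a
    relaxed one-hot, depending on mode) with SGD so that the decoder reproduces the
    target sentences in lis_tgt_w; sigmoid/softmax modes can add a lasso-style
    sparsity penalty (LASSO_LAMBDA). Runs 8 epochs of 500 iterations and returns
    the optimized input variable together with the final average loss."""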
    assert mode in ('embed', 'softmax_idx', 'linear_idx', 'softplus_idx',
                    'sigmoid_idx')
    bz = len(lis_tgt_w)
    tgt_len = [len(l) - 1 for l in lis_tgt_w]
    max_len = max(tgt_len)
    v_list = []
    if mode == 'embed':
        embed_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY, EMBED_SIZE).cuda(),
                           requires_grad=True)
        v_list.append(embed_v)
    elif mode == 'softmax_idx':
        ll = len(vocab)
        if SOFTMAX_IDX_EL >= 0:
            ll = SOFTMAX_IDX_EL
        onehot_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY, ll).cuda(),
                            requires_grad=True)
        v_list.append(onehot_v)
    elif mode == 'sigmoid_idx':
        onehot_v = Variable(
            torch.randn(bz, ADV_SRC_LEN_TRY, len(vocab)).cuda() - 3,
            requires_grad=True)
        v_list.append(onehot_v)
    elif mode == 'linear_idx' or mode == 'softplus_idx':
        onehot_v = Variable(torch.randn(bz, ADV_SRC_LEN_TRY,
                                        len(vocab)).cuda(),
                            requires_grad=True)
        v_list.append(onehot_v)

    #latent_v_ori = Variable(torch.randn(bz, 1, HIDDEN_SIZE).cuda(), requires_grad = True)
    #latent_v = latent_v_ori.repeat(1, max_len, 1)
    lis_tgt_w = [l + ['<pad>'] * (max_len + 1 - len(l)) for l in lis_tgt_w]
    lis_tgt_idx = [[vocab_inv[w] for w in l] for l in lis_tgt_w]
    tgt_mb = Variable(torch.LongTensor(lis_tgt_idx).cuda())
    tgt_inputv = tgt_mb[:, :-1]
    tgt_targetv = tgt_mb[:, 1:]

    tgt_mask_ori = Variable(torch.FloatTensor([
        ([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len
    ]).cuda(),
                            requires_grad=False)

    lr = 1 * bz  #when the batch is larger than one, the learning rate becomes small, for mode 'embed' lr=1 works best
    for epoch in range(8):  #8!!
        opt_v = torch.optim.SGD(
            v_list, momentum=0.9, lr=lr,
            weight_decay=1e-5)  #the larger the better, it seems
        #opt_v = torch.optim.Adam(v_list, lr = 1e-4) #debug adam!
        #lr = lr * 0.6 #exp shows const 1 is better~
        for it in range(500 * epoch, 500 * (epoch + 1)):
            #print latent_v.size(), tgt_inputv.size(), tgt_len
            #latent_v = latent_v_ori.repeat(1, max_len, 1)
            if mode == 'embed':
                output, _ = m_encode_w_rnn(
                    embed_v,
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'softmax_idx':
                #output, _ = m_encode_w_rnn(m_embed(F.softmax(onehot_v, dim = 2)), init_lstm_hidden(bz, HIDDEN_SIZE))
                output, _ = m_encode_w_rnn(
                    m_embed(softmax_idx_el_glue(onehot_v)),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'linear_idx':
                output, _ = m_encode_w_rnn(
                    m_embed(onehot_v),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'softplus_idx':
                output, _ = m_encode_w_rnn(
                    m_embed(F.softplus(onehot_v)),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            elif mode == 'sigmoid_idx':
                output, _ = m_encode_w_rnn(
                    m_embed(torch.sigmoid(onehot_v)).permute(1, 0, 2),
                    init_lstm_hidden(bz, HIDDEN_SIZE, layer_num=LAYER_NUM))
            output = output.permute(1, 0, 2)
            #latent = output[:, -1, :].unsqueeze(1).repeat(1, tgt_targetv.size(1), 1)
            #latent = Variable(torch.zeros(latent.size())).cuda() #debug!
            #print 'hn:', hn.size() #[1, bz, HIDDEN_SIZE]

            w_logit_rnn, attn_weights = decoder_forward(
                output, tgt_inputv, tgt_len)
            #w_logit_rnn = m_decode_w_rnn.decode(latent, tgt_inputv, tgt_len)
            flat_output = w_logit_rnn.view(-1, len(vocab))
            flat_target = tgt_targetv.contiguous().view(-1)
            flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output,
                                                       flat_target)
            tgt_mask = Variable(torch.FloatTensor([
                ([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len
            ]).cuda(),
                                requires_grad=False)
            w_logit_pred = torch.max(w_logit_rnn, dim=2)[1]
            for i in range(bz):
                for j in range(tgt_len[i]):
                    if w_logit_pred[i][j] == tgt_targetv[i][j]:
                        tgt_mask[i][j] = WEIGHT_LOSS_DECAY
                        #if tgt_targetv[i][j] != 0:
                        #    logger.info('setting to 0.01: |%s|', vocab[tgt_targetv[i][j]])
            batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
            w_loss_rnn = torch.sum(batch_logpdf)

            w_loss_rnn_ori = torch.sum(flat_logpdf.view(bz, -1) * tgt_mask_ori)

            opt_v.zero_grad()
            (-w_loss_rnn / sum(tgt_len)).backward(retain_graph=True)

            if (mode == 'sigmoid_idx'
                    or mode == 'softmax_idx') and LASSO_LAMBDA > 0:
                if mode == 'sigmoid_idx':
                    sv = torch.sigmoid(onehot_v)
                if mode == 'softmax_idx':
                    sv = F.softmax(onehot_v, dim=2)
                lasso_loss = LASSO_LAMBDA * torch.sum(sv,
                                                      dim=2).view(-1).mean()
                lasso_loss += (-LASSO_LAMBDA * 2) * (torch.max(
                    sv, dim=2)[0]).view(-1).mean()
                lasso_loss.backward(retain_graph=True)

            #print torch.max(torch.abs(onehot_v.grad)) #softmaxed gradient is very small
            opt_v.step()

            avg_loss = (-w_loss_rnn_ori / sum(tgt_len)).item()
            if it % 200 == 0:
                logger.info('it %d avg_loss: %f', it, avg_loss)

    if mode == 'embed':
        return embed_v, avg_loss
    elif mode in ('softmax_idx', 'linear_idx', 'softplus_idx', 'sigmoid_idx'):
        return onehot_v, avg_loss
Example #6
def mle_train(batches, opt, do_train=True, do_log=True):
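    """One pass of maximum-likelihood training (or evaluation when do_train=False)
    over batches of (src_mb, tgt_mb, tgt_len, src_w, tgt_w); clips gradients to
    norm 5 and logs every LOG_INTERVAL batches. Returns the average per-token loss."""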
    for m in m_dict.values():
        if do_train:
            m.train()
        else:
            m.eval()

    all_loss = 0
    all_num = 0
    b_count = 0
    loss_sen = []
    for src_mb, tgt_mb, tgt_len, src_w, tgt_w in batches:
        #print src_w[0], src_w[1]; sys.exit(1)
        loss = 0
        b_count = b_count + 1
        bz = src_mb.size(0)
        all_num = all_num + sum(tgt_len)

        src_inputv = Variable(src_mb).cuda()
        tgt_inputv = Variable(tgt_mb[:, :-1]).cuda()
        tgt_targetv = Variable(tgt_mb[:, 1:]).cuda()
        tgt_mask = Variable(torch.FloatTensor([
            ([1] * l + [0] * (tgt_inputv.size(1) - l)) for l in tgt_len
        ]).cuda(),
                            requires_grad=False)

        #mask = Variable(mask_gen(b_len)).cuda()

        #size(batch, length)
        output, _ = m_encode_w_rnn_dp(
            m_embed_dp(idx2onehot(src_inputv, len(vocab))).permute(1, 0, 2),
            init_lstm_hidden(bz, HIDDEN_SIZE,
                             layer_num=LAYER_NUM))  #for parallel!
        output = output.permute(1, 0, 2)  #for parallel!
        w_logit_rnn, attn_weights = decoder_forward(output, tgt_inputv,
                                                    tgt_len)
        flat_output = w_logit_rnn.view(-1, len(vocab))
        flat_target = tgt_targetv.contiguous().view(-1)
        flat_logpdf = lib_pdf.logsoftmax_idxselect(flat_output, flat_target)
        batch_logpdf = flat_logpdf.view(bz, -1) * tgt_mask
        w_loss_rnn = torch.sum(batch_logpdf)
        loss_sen.extend(
            torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())

        all_loss = all_loss + w_loss_rnn.data.item()

        if do_train:
            for m in m_dict.values():
                m.zero_grad()
            (-w_loss_rnn / sum(tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt.step()

        if do_log and b_count % LOG_INTERVAL == 0:
            logger.info('avg loss at b: %d , %f', b_count,
                        all_loss * 1.0 / all_num)

    logger.info('all_num: %d', all_num)
    logger.info('sen_avg_loss: %f', np.mean(loss_sen))
    return float(all_loss * 1.0 / all_num)