def mle_train(batches,
              opt,
              do_train=True,
              do_log=True,
              do_adv=False,
              adv_config=None):
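    """Run one pass of MLE training (or evaluation, when do_train=False) over `batches`.

    For every minibatch the per-token log-probabilities returned by
    models.encoder_decoder_forward are summed and, if do_train is set, the
    negative average log-likelihood is back-propagated through the modules in
    m_dict and applied with `opt`. If do_adv is set, every ADV_RATIO-th batch
    is followed by a negative update that pushes down the log-probability of an
    adversarial minibatch drawn from adv_config['inf_adv_batches'].
    Returns the average per-token log-likelihood of the pass (0 if no tokens were seen).
    """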
    for m in m_dict:
        if do_train:
            m_dict[m].train()
        else:
            m_dict[m].eval()

    if do_adv:
        inf_adv_batches = adv_config['inf_adv_batches']

    all_loss, all_num, adv_all_loss, adv_all_num, nodecay_sennum = 0, 0, 0, 0, 0
    b_count, adv_b_count = 0, 0
    loss_sen = []
    for src_mb, tgt_mb, tgt_len, src_w, tgt_w in batches:
        #print src_w[0], src_w[1]; sys.exit(1)
        loss = 0
        b_count = b_count + 1
        bz = src_mb.size(0)
        all_num = all_num + sum(tgt_len)

        batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len,
                                                      m_dict, adv_config)
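        # batch_logpdf: per-token log-probabilities, shape (batch, max target length)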

        if False:
            #===drawing debug===
            logprob_lis = []
            w_lis = []
            for i in range(bz // 2):
                ss = ""
                for j in range(tgt_len[i]):
                    ss += tgt_w[i][j + 1] + '(' + '%.2f' % batch_logpdf[i][j].item() + ') '
                    logprob_lis.append(batch_logpdf[i][j].item())
                    if batch_logpdf[i][j] < -1000:
                        print(tgt_w[i][j + 1], batch_logpdf[i][j])
                    w_lis.append(tgt_w[i][j + 1])
                #print ss
            #torch.save(logprob_lis, 'figs/advtrain_wordlogps/naivetrainR' + str(ADV_RATIO) + '_' + SUC + '.data')
            #torch.save(w_lis, 'figs/advtrain_wordlogps/naivetrainR' + str(ADV_RATIO) + '_' + SUC + '_w.data')
            #sys.exit(1)
            #===end===

        #print torch.min(batch_logpdf)

        w_loss_rnn = torch.sum(batch_logpdf)
        loss_sen.extend(
            torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())

        all_loss = all_loss + w_loss_rnn.data.item()

        if do_train:
            for m in m_dict.values():
                m.zero_grad()
            (-w_loss_rnn / sum(tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt.step()

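        # every ADV_RATIO-th batch: additionally push down the log-probability of an adversarial minibatch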
        if do_train and do_adv and b_count % ADV_RATIO == 0:
            target_mb = next(inf_adv_batches)
            src_mb, tgt_mb, tgt_len, src_lis, tgt_lis = get_adv_seq2seq_mb(
                target_mb, ADV_ATTACK, m_dict, adv_config)
            #print tgt_lis[0], tgt_mb[0], tgt_len[0]
            for m in m_dict:  #get_adv_seq2seq_mb could change the flags
                m_dict[m].zero_grad()
                m_dict[m].train()
            adv_batch_logpdf = models.encoder_decoder_forward(
                src_mb, tgt_mb, tgt_len, m_dict, adv_config)
            decay_co = get_decay_co(adv_batch_logpdf, tgt_len, adv_config)
            nodecay_sennum += np.sum(decay_co.cpu().numpy() == 1)
            adv_b_count += 1
            sen_logpdf = torch.sum(adv_batch_logpdf, dim=1) * decay_co
            (ADV_LAMBDA * torch.sum(sen_logpdf) / sum(tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt.step()
            adv_all_loss += torch.sum(adv_batch_logpdf).item()
            adv_all_num += sum(tgt_len)

        if do_log and b_count % LOG_INTERVAL == 0:
            logger.info('avg loss at b: %d , %f', b_count,
                        all_loss * 1.0 / all_num)

    logger.info('all_num: %d', all_num)
    logger.info('sen_avg_loss: %f', np.mean(loss_sen))
    if do_adv:
        logger.info('adv_all_num: %d avg_nodecay_sennum: %f', adv_all_num,
                    float(nodecay_sennum) / adv_b_count)
        logger.info('avg_adv_loss: %f', adv_all_loss / adv_all_num)
    if all_num != 0:
        return float(all_loss * 1.0 / all_num)
    else:
        return 0
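    #===test-set evaluation: report PPL and dump per-batch log-probabilities===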
    opt = torch.optim.SGD(all_params, momentum=0.9, lr=0, weight_decay=1e-5)
    loss_test = mle_train(batches_test,
                          opt,
                          do_train=False,
                          adv_config=adv_config)
    logger.info('test PPL: %f log-likelihood: %f', math.exp(-loss_test),
                loss_test)

    res = {'loss_test': loss_test, 'mb_lis': []}
    batches_test = DialogueBatches([TEST_FN],
                                   BATCH_SIZE,
                                   SRC_SEQ_LEN,
                                   TGT_SEQ_LEN,
                                   vocab_inv,
                                   DATA_CONFIG,
                                   his_len=HIS)
    for src_mb, tgt_mb, tgt_len, src_w, tgt_w in batches_test:
        batch_logpdf = models.encoder_decoder_forward(src_mb, tgt_mb, tgt_len,
                                                      m_dict, adv_config)
        res['mb_lis'].append(
            (src_mb.detach().cpu(), tgt_mb.detach().cpu(), tgt_len, src_w,
             tgt_w, batch_logpdf.detach().cpu()))

    res_save_fn = 'figs/post_mal_advtrain/' + DATA_SET + '_IT' + str(
        TEST_ITER) + '_NEGLR' + str(NEG_LR) + '_POSLR' + str(
            POS_LR) + '_FAVOID' + str(FREQ_AVOID) + '_ILM' + str(
                ADV_I_LM_FLAG) + '_MFOCUS' + str(
                    MIDDLE_FOCUS) + '_CARE' + ADV_CARE_MODE + '_testlogp.save'
    print('saving to res_save_fn:', res_save_fn)
    torch.save(res, res_save_fn)
def adv_train(adv_batches,
              positive_batches,
              opts,
              do_train=True,
              do_log=True,
              do_adv=False,
              adv_config=None):
    #adv_train runs the adversarial (negative) update first in each iteration
    #assert(do_train == True)
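    """Adversarial ("negative") training loop.

    For each adversarial minibatch built by get_adv_seq2seq_mb, the model is
    repeatedly updated with opt_neg to lower the (masked) log-probability of
    the adversarial targets, interleaved with POSITIVE_RATIO standard MLE
    updates on positive_batches (opt_pos), until no target is still hit or
    more than 20 passes have been made; in evaluation mode (do_train=False)
    only one pass per minibatch is done. Returns a dict with the positive
    average loss, hit statistics and the collected attack_results.
    """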
    opt_pos, opt_neg = opts
    all_loss, all_num, adv_all_loss, adv_all_num, nodecay_sennum, success0_co, success1_co = 0, 0, 0, 0, 0, 0, 0
    b_count, adv_b_count, adv_sen_count = 0, 0, 0
    attack_results = []
    loss_sen = []
    all_target_set = {}
    for target_mb in adv_batches:
        adv_src_mb, adv_tgt_mb, adv_tgt_len, adv_src_lis, adv_tgt_lis = get_adv_seq2seq_mb(
            target_mb, ADV_ATTACK, m_dict, adv_config)
        for l in adv_tgt_lis:
            all_target_set[' '.join(l[1:])] = True
        nondecay_lis_now = []
        co_now = 0
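        # keep alternating negative/positive updates on this adversarial minibatch
        # until no target is hit any more (or more than 20 passes have been made)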
        while True:
            #print tgt_lis[0], tgt_mb[0], tgt_len[0]
            for m in m_dict:  #get_adv_seq2seq_mb could change the flags
                m_dict[m].zero_grad()
                m_dict[m].train()
            aux = {}
            adv_batch_logpdf = models.encoder_decoder_forward(adv_src_mb,
                                                              adv_tgt_mb,
                                                              adv_tgt_len,
                                                              m_dict,
                                                              adv_config,
                                                              aux_return=aux)
            w_logit_rnn = aux['w_logit_rnn']
            decay_co = get_decay_co(adv_batch_logpdf, w_logit_rnn, adv_tgt_mb,
                                    adv_tgt_len, adv_config)
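            # decay_co: per-sentence weights from get_decay_co; sentences with
            # decay_co == 1 are counted below as still-"hit" adversarial targets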
            attack_results.append(
                (adv_src_mb, adv_tgt_mb, adv_tgt_len, adv_src_lis, adv_tgt_lis,
                 decay_co, adv_batch_logpdf))
            final_mask = torch.ones_like(adv_batch_logpdf)
            if MIDDLE_FOCUS:
                for i in range(len(adv_tgt_len)):
                    final_mask[i][0] = 0
                    final_mask[i][adv_tgt_len[i] - 1] = 0
            if FREQ_AVOID:
                for i in range(len(adv_tgt_len)):
                    for j in range(adv_tgt_len[i]):
                        if adv_tgt_lis[i][j + 1] in FREQ_AVOID_LIS:
                            final_mask[i][j] = 0
                """
                for i in range(20):
                    print i, 'adv_tgt_lis:', adv_tgt_lis[i]
                    print 'mask:', final_mask[i]
                    print 'adv_tgt_len', adv_tgt_len[i]
                    print 'adv_batch_logpdf', adv_batch_logpdf[i]
                    print '*final_mask', (adv_batch_logpdf * final_mask)[i]
                """

            sen_logpdf = torch.sum(adv_batch_logpdf * final_mask,
                                   dim=1) * decay_co
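            # minimizing +ADV_LAMBDA * sen_logpdf with opt_neg lowers the model's
            # probability of the (masked) adversarial targets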
            (ADV_LAMBDA * torch.sum(sen_logpdf) / sum(adv_tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            if do_train:
                opt_neg.step()
            """
            #just for debug!!
            for i in range(decay_co.size(0)):
                if decay_co[i] == 1:
                    _, tgt_w, _ = target_mb
                    s_lis = adv_src_mb[i].cpu().numpy().tolist()
                    print 'hit! target:', ' '.join(tgt_w[i]), 't_input:', ' '.join([vocab[ii] for ii in s_lis])
            print 'hit hit hit in first mb'
            """

            nondecay_num_now = np.sum(decay_co.cpu().numpy() == 1)
            #logger.info('co_now: %d nondecay_num_now: %d', co_now, nondecay_num_now)
            nondecay_lis_now.append(nondecay_num_now)
            if co_now == 0:
                logger.info('displaying hit targets')
                for i in range(len(adv_tgt_len)):
                    if decay_co[i] == 1:
                        print("hit! target:", ' '.join(adv_tgt_lis[i]),
                              'trigger:', ' '.join(adv_src_lis[i]))
                adv_sen_count += adv_src_mb.size(0)
                adv_all_loss += torch.sum(adv_batch_logpdf).item()
                adv_all_num += sum(adv_tgt_len)
                nodecay_sennum += nondecay_num_now
                adv_b_count += 1
                if not do_train:
                    break
            co_now += 1

            for kk in range(POSITIVE_RATIO):
                src_mb, tgt_mb, tgt_len, src_w, tgt_w = next(positive_batches)
                loss = 0
                b_count = b_count + 1
                bz = src_mb.size(0)
                all_num = all_num + sum(tgt_len)

                batch_logpdf = models.encoder_decoder_forward(
                    src_mb, tgt_mb, tgt_len, m_dict, adv_config)

                w_loss_rnn = torch.sum(batch_logpdf)
                loss_sen.extend(
                    torch.sum(batch_logpdf,
                              dim=1).detach().cpu().numpy().tolist())

                all_loss = all_loss + w_loss_rnn.data.item()

                if do_train:
                    for m in m_dict.values():
                        m.zero_grad()
                    (-w_loss_rnn / sum(tgt_len)).backward()
                    for m in m_dict.values():
                        torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
                    opt_pos.step()
            if nondecay_num_now == 0 or co_now > 20:
                if nondecay_num_now <= 0: success0_co += 1
                if nondecay_num_now <= 1: success1_co += 1
                break
        logger.info('current mb nondecay_lis_now(through each update): %s',
                    str(nondecay_lis_now))

    logger.info('all_num: %d', all_num)
    if all_num > 0:
        logger.info('sen_avg_loss: %f', np.mean(loss_sen))
    if do_adv:
        logger.info(
            'iter: %d adv_all_num: %d avg_nodecay_sennum: %f success0_rate: %f success1_rate: %f',
            adv_config['iter_now'], adv_all_num,
            float(nodecay_sennum) / adv_b_count,
            success0_co * 100.0 / adv_b_count,
            success1_co * 100.0 / adv_b_count)
        logger.info('debug for avg_nodecay_sennum: %d / %d', nodecay_sennum,
                    adv_b_count)
        logger.info('avg_adv_loss: %f', adv_all_loss / adv_all_num)

    res = {
        'positive_avg_loss': 0 if all_num == 0 else all_loss * 1.0 / all_num,
        'nodecay_sennum': nodecay_sennum,
        'adv_sen_count': adv_sen_count,
        'attack_results': attack_results,
        'all_target_set': all_target_set,
    }
    return res
def neg_d_train(aux_batches,
                train_batches,
                opts,
                do_train=True,
                do_log=True,
                do_adv=False,
                adv_config=None):
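    """Negative training against frequent model samples, interleaved with
    discriminator and positive MLE updates.

    For each training minibatch, samples are drawn from the model itself;
    once enough sample history has accumulated (> 200 minibatches), samples
    whose frequency ratio exceeds R_THRES get neg_lambda = 1 and their
    log-probability is pushed down with opt_neg. Each such step is followed
    by D_RATIO discriminator updates (opt_d) and POSITIVE_RATIO MLE updates
    on real data (opt_pos). Returns the accumulated statistics (MyStatDic).
    """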
    #as in adv_train, the negative update is applied first in each iteration
    #assert(do_train == True)
    opt_pos, opt_neg, opt_d = opts
    all_loss, all_num, adv_all_loss, adv_all_num, nodecay_sennum, success0_co, success1_co = 0, 0, 0, 0, 0, 0, 0
    attack_results = []
    loss_sen = []
    all_target_set = {}
    inf_pos_batches, inf_d_batches = aux_batches

    d_b_co, pos_b_co, neg_b_co = 0, 0, 0
    sample_mb_longhis = []
    stat_dic = MyStatDic()
    for train_mb in train_batches:
        neg_b_co += 1
        src_mb, tgt_mb, tgt_len, src_w, tgt_w = train_mb
        bz = src_mb.size(0)
        res_sample = models.get_samples([train_mb], TRAIN_SAMPLE_TYPE, m_dict,
                                        adv_config)
        sample_mb = res_sample['raw_sample_lis'][0]
        samples_w = res_sample['sample_lis']
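        # sample_mb: the raw sampled responses for this minibatch;
        # samples_w: the same samples as lists of words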

        neg_lambda = [0 for kk in range(len(samples_w))]
        sample_mb_longhis.append(samples_w)
        if len(sample_mb_longhis) > 200:
            sample_mb_longhis = sample_mb_longhis[1:]
            d_r = make_ratio_dict(sample_mb_longhis)
            for i in range(BATCH_SIZE):
                #print samples_w[i]
                if d_r[' '.join(samples_w[i])] > R_THRES:
                    neg_lambda[i] = 1
                    #print 'found frequent response:', ' '.join(samples_w[i]), 'ratio:', d_r[' '.join(samples_w[i])]
        neg_lambda = torch.FloatTensor(neg_lambda).cuda()
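        # neg_lambda[i] = 1 marks samples whose frequency ratio in the recent sample
        # history (d_r from make_ratio_dict) exceeds R_THRES; only these frequent
        # responses receive the negative update below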

        sample_tgt_mb, sample_tgt_len = form_tgtmb(sample_mb)
        for m in m_dict.values():
            m.zero_grad()
            m.train()
        aux = {}
        adv_batch_logpdf = models.encoder_decoder_forward(src_mb,
                                                          sample_tgt_mb,
                                                          sample_tgt_len,
                                                          m_dict,
                                                          adv_config,
                                                          aux_return=aux)
        w_logit_rnn = aux['w_logit_rnn']

        final_mask = torch.ones_like(adv_batch_logpdf)
        if FREQ_AVOID:
            for i in range(len(sample_tgt_len)):
                final_mask[i][sample_tgt_len[i] - 1] = FREQ_AVOID_SCAL
            #for i in range(5):
            #    print sample_tgt_mb[i]
            #    print adv_batch_logpdf[i]
            #    print final_mask[i]
            #sys.exit(1)

        neg_logpdf = torch.sum(adv_batch_logpdf * final_mask,
                               dim=1) * neg_lambda
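        # minimizing +neg_logpdf with opt_neg pushes down the probability of the
        # selected frequent samples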
        if do_train:
            (torch.sum(neg_logpdf) / sum(sample_tgt_len)).backward()
            for m in m_dict.values():
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
            opt_neg.step()

        stat_dic.append_dict({'neg_lambda': neg_lambda.mean().item()})
        stat_dic.append_dict({
            'neg_loss':
            (torch.sum(adv_batch_logpdf).item() / sum(sample_tgt_len))
        })
        if neg_b_co % 2000 == 0 and DEBUG_INFO:
            print('peek mb sample!')
            for i in range(sample_tgt_mb.size(0) // 5):
                logger.info(
                    'neg_lambda: %f %s', neg_lambda[i].item(), ' '.join([
                        vocab[sample_tgt_mb[i][j]]
                        for j in range(sample_tgt_len[i] + 1)
                    ]))
            #wait = raw_input("PRESS ENTER")

        src_mb, tgt_mb, tgt_len, src_w, tgt_w = train_mb
        for kk in range(D_RATIO):
            res_d = d_train_mb(next(inf_d_batches), opt_d, do_train=do_train)
            d_b_co += 1
            stat_dic.append_dict(res_d, keys=['d_loss', 'd_acc_ratio'])

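        # POSITIVE_RATIO standard MLE updates on positive (real) data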
        for kk in range(POSITIVE_RATIO):
            src_mb, tgt_mb, tgt_len, src_w, tgt_w = next(inf_pos_batches)
            pos_b_co += 1
            bz = src_mb.size(0)

            for m in m_dict.values():
                m.train()
                m.zero_grad()

            batch_logpdf = models.encoder_decoder_forward(
                src_mb, tgt_mb, tgt_len, m_dict, adv_config)

            w_loss_rnn = torch.sum(batch_logpdf)
            loss_sen.extend(
                torch.sum(batch_logpdf, dim=1).detach().cpu().numpy().tolist())

            #all_loss = all_loss + w_loss_rnn.data.item()
            stat_dic.append_dict(
                {'pos_loss': (-w_loss_rnn.item() / sum(tgt_len))})

            if do_train:
                (-w_loss_rnn / sum(tgt_len)).backward()
                for m in m_dict.values():
                    torch.nn.utils.clip_grad_norm_(m.parameters(), 5)
                opt_pos.step()

        if neg_b_co % 200 == 0:
            stat_dic.log_mean(last_num=200,
                              log_pre='it{} neg_b_co: {}'.format(
                                  adv_config['iter_now'], neg_b_co))

    return stat_dic