Ejemplo n.º 1
0
def getSingleTrainData():
    """Build one training record per (sample, candidate citation) pair.

    Loads the training samples, scores each candidate citation's processed
    target text against the sample's own processed target text with
    ``test_bleu(..., 1)`` and keeps one record per candidate, skipping
    candidates whose score is exactly 1.  The records are saved to
    ``./train_data/single_train_data.pkl``.

    Each input sample is expected to look like this (TODO confirm against the
    data-preparation code)::

        {
            "citStr": "...",             # author(s) and year of the citation
            "context": "...",            # the whole citing passage
            "up_source_tokens": "...",
            "down_source_tokens": "...",
            "target_tokens": "...",
            "citations_tokens": [        # candidate citations
                {"up_source_tokens": "...",
                 "down_source_tokens": "...",
                 "target_tokens": "..."},
                ...
            ],
        }

    Each saved record has the keys ``up_source``, ``down_source``,
    ``target``, ``cit_up_source``, ``cit_down_source``, ``cit_target`` and
    ``bleu1_score``.
    """
    datas = pickleload("../data2/train_data.pkl", "./data2/train_data.pkl")
    print(len(datas))
    train_datas = []
    for data in tqdm(datas):
        # These fields do not depend on the candidate: process them once per
        # sample instead of once per candidate.
        up_source = data_process(data["up_source_tokens"])
        down_source = data_process(data["down_source_tokens"])
        target = data_process(data["target_tokens"])

        for citation in data["citations_tokens"]:
            cit_target = data_process(citation["target_tokens"])
            score = test_bleu(cit_target, target, 1)
            # A score of exactly 1 is excluded from the training set.
            if score == 1:
                continue
            train_datas.append({
                'up_source': up_source,
                'down_source': down_source,
                'target': target,
                'cit_up_source': data_process(citation['up_source_tokens']),
                'cit_down_source': data_process(
                    citation['down_source_tokens']),
                'cit_target': cit_target,
                'bleu1_score': score,
            })
    print("训练样本的数据量为:", len(train_datas))
    picklesave(train_datas, "./train_data/single_train_data.pkl",
               "single_train_data.pkl")
Ejemplo n.º 2
0
def getRandomData():
    """Shuffle the training samples and save them in random order.

    Reads ``../data2/train_data2.pkl``, reorders the samples with a
    ``numpy.random.permutation`` of their positions and writes the shuffled
    list to ``../data2/random_train_data.pkl``.
    """
    import numpy as np

    source = pickleload("../data2/train_data2.pkl", "./data2/train_data2.pkl")
    order = np.random.permutation(range(len(source)))
    shuffled = [source[pos] for pos in order]
    picklesave(shuffled, "../data2/random_train_data.pkl",
               "./data2/random_train_data.pkl")
Ejemplo n.º 3
0
def getWord2index():
    """Build the token -> index vocabulary from the training samples.

    Tokens are taken from the processed target, up-source and down-source
    texts of every sample (split on single spaces) and counted across the
    whole training set.  Tokens seen more than once receive consecutive
    indices starting at 2, in first-seen order; tokens seen only once are
    left out of the vocabulary.  Index 0 is ``<padding>``, 1 is ``<unknow>``,
    and ``<CLS>``, ``<DSP>`` and ``<MASK>`` take the three indices after the
    last vocabulary token.  The mapping is saved to
    ``./word_vec/word2index.pkl``.
    """
    from collections import Counter

    datas = pickleload("../data2/train_data.pkl", "./data2/train_data.pkl")
    print(len(datas))

    # Counter keeps first-seen insertion order, so the index assignment
    # below is deterministic for a given dataset.
    tokenDic = Counter()
    for data in tqdm(datas):
        for field in ("target_tokens", "up_source_tokens",
                      "down_source_tokens"):
            tokenDic.update(data_process(data[field]).split(" "))

    frequent = [token for token, count in tokenDic.items() if count > 1]
    word2index = {token: idx for idx, token in enumerate(frequent, start=2)}
    index = 2 + len(frequent)
    word2index['<padding>'] = 0
    word2index['<unknow>'] = 1
    word2index['<CLS>'] = index
    word2index['<DSP>'] = index + 1
    word2index['<MASK>'] = index + 2
    print(len(word2index), "  /  ", len(tokenDic), "个token")
    picklesave(word2index, './word_vec/word2index.pkl', "word2index.pkl")
Ejemplo n.º 4
0
def manual_label():
    """Interactively label the candidate citations of every training sample.

    For each candidate citation, the processed candidate target and the
    sample's processed target are printed and the user is prompted for a
    label ("1" means positive; any other answer means 0).  Answers are cached
    in ``flag_pairs`` under both concatenation orders, so a pair that was
    already answered is labelled automatically.  The cache is saved to
    ``../data/flag_pairs.pkl`` after every sample.  Samples with at least one
    positive candidate are collected and saved to
    ``../data/golden_train_data.pkl`` as soon as they are found.

    NOTE(review): cache keys are plain string concatenations, so distinct
    pairs can collide (``"ab" + "c" == "a" + "bc"``).  A tuple key would be
    unambiguous, but it would change the format of the saved cache.
    NOTE(review): the two path arguments of ``pickleload`` name different
    files (train_data2.pkl vs train_data.pkl). Confirm which one is intended.
    """
    datas = pickleload("../data2/train_data2.pkl", "./data2/train_data.pkl")
    total = len(datas)
    print(total)
    train_datas = []
    flag_pairs = {}
    for i, data in enumerate(datas):
        target = data_process(data["target_tokens"])
        citations = data["citations_tokens"]
        n_cit = len(citations)
        has_positive = False
        for index, citation in enumerate(citations):
            cand_cit = data_process(citation["target_tokens"])
            pair_key = cand_cit + target
            if pair_key in flag_pairs:
                # Reuse the earlier answer for this pair.
                citation['label'] = 1 if flag_pairs[pair_key] == 1 else 0
            else:
                print("进程:", i, "/", total, "  ", index, "/", n_cit)
                print("target:", target)
                print("candidate:", cand_cit)
                answer = input("标签:")
                citation['label'] = 1 if str(answer) == "1" else 0
                flag_pairs[pair_key] = citation['label']
                flag_pairs[target + cand_cit] = citation['label']
            if citation['label'] == 1:
                has_positive = True
        # Checkpoint after every sample so a long labelling session survives
        # interruption.
        picklesave(flag_pairs, "../data/flag_pairs.pkl",
                   "./data/flag_pairs.pkl")
        if has_positive:
            train_datas.append(data)
            picklesave(train_datas, "../data/golden_train_data.pkl",
                       "./data/golden_train_data.pkl")
Ejemplo n.º 5
0
def getIdf2():
    """Compute per-token inverse document frequencies over citation contexts.

    ``./data/processed_data3.pkl`` is read as a dict whose values are
    sequences of records carrying a ``"context"`` string (schema inferred from
    usage; TODO confirm).  Every record counts as one document, tokenised
    with ``nltk.word_tokenize``.  A token's IDF is ``log10(N / df)``, where
    ``N`` is the total number of records and ``df`` is the number of records
    that contain the token.  The mapping is saved to ``./data2/idf2.pkl``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open("./data/processed_data3.pkl", "rb") as f:
        citationDic = pickle.load(f)

    all_count = 0
    tokenidf_dic = {}
    for records in tqdm(citationDic.values()):
        all_count += len(records)
        for record in records:
            # dict.fromkeys de-duplicates (document frequency, not term
            # frequency) while keeping first-seen order.
            for token in dict.fromkeys(nltk.word_tokenize(record['context'])):
                tokenidf_dic[token] = tokenidf_dic.get(token, 0) + 1

    new_dic = {token: math.log10(all_count / df)
               for token, df in tokenidf_dic.items()}
    picklesave(new_dic, './data2/idf2.pkl', "idf")
Ejemplo n.º 6
0
def getIdf():
    """Compute per-token inverse document frequencies over the training set.

    Every sample counts as one document made of its processed up-source,
    down-source and target texts, each split on single spaces.  A token's IDF
    is ``log10(N / df)``, where ``N`` is the number of samples and ``df`` is
    the number of samples that contain the token.  The mapping is saved to
    ``./data2/idf.pkl``.
    """
    datas = pickleload("./data2/train_data.pkl", "./data/train_data.pkl")
    all_count = len(datas)
    print(all_count)
    doc_freq = {}
    for sample in tqdm(datas):
        tokens = (process(sample["up_source_tokens"]).split(" ")
                  + process(sample["down_source_tokens"]).split(" ")
                  + process(sample["target_tokens"]).split(" "))
        # Count each token at most once per sample; dict.fromkeys keeps
        # first-seen order.
        for token in dict.fromkeys(tokens):
            doc_freq[token] = doc_freq.get(token, 0) + 1
    idf = {token: math.log10(all_count / df) for token, df in doc_freq.items()}
    picklesave(idf, './data2/idf.pkl', "idf")
Ejemplo n.º 7
0
def all_doubletrainKey(args):
    """Train ``AllClassifyGetKeyWords`` on (high, low) citation pairs.

    The model wraps a ``Transformer`` whose weights are loaded from
    ``./modelsave/TransformModel0.pkl``.  Each training batch yields a
    higher-ranked and a lower-ranked candidate.  The loss is
    ``relu(mean(1 - high + low) + sum of mean squared distances of the four
    auxiliary outputs from 7)`` (see ``loss_func``).  Each time the batch
    generator first reports a new epoch value (1, 2, ...), the model is
    evaluated on the last fifth of ``../data2/random_train_data.pkl`` with
    ``cal_MAP`` on the top-2 predictions per sample id.  The best-MAP weights
    are saved to ``./modelsave/max<modelName>`` and the latest weights to
    ``./modelsave/<modelName>``.  CUDA is required.

    Args:
        args: run configuration.  Reads ``context_limit``, ``num_epoches``,
            ``batch_size``, ``optim``, ``learning_rate``,
            ``init_weight_decay``, ``loadmodel``, ``loadmodelName``,
            ``d_model`` and ``modelName`` directly, plus whatever ``Batch``,
            ``Transformer`` and the model read.

    Raises:
        ValueError: if ``args.optim`` is not "Adadelta", "Adam" or "SGD".
    """
    data = pickleload(
        '../Retrieval/train_data/small_pairs_random_train_data.pkl',
        "small_pairs_random_train_data")
    dev_data = pickleload("../data2/random_train_data.pkl", "dev_data")
    # The pair data is a list of chunks; the first four are used for training.
    train_data = data[0] + data[1] + data[2] + data[3]
    # The last fifth of the shuffled samples is the dev split.
    dev_data = dev_data[len(dev_data) * 4 // 5:len(dev_data)]

    batch = Batch(args)
    word2index = pickleload("./word_vec/word2index.pkl", "word2index.pkl")
    input_vec = len(word2index)

    train_batches = batch.double_train_batch(train_data, args.context_limit,
                                             args.num_epoches, args.batch_size)

    log_msg = "输入词空间大小:%d" % (input_vec)
    logger.info(log_msg)
    print(log_msg)

    transform = Transformer(args, input_vec)
    if torch.cuda.is_available():
        transform = transform.cuda()
    transform.load_state_dict(
        torch.load("./modelsave/" + "TransformModel0.pkl"))

    model = AllClassifyGetKeyWords(args, transform)
    model = model.cuda()
    # Loaded once, before the optimizer exists. load_state_dict copies into
    # the existing parameter tensors, so a second load is unnecessary.
    if args.loadmodel == True:
        model.load_state_dict(torch.load("./modelsave/" + args.loadmodelName))

    parameters_trainable = [p for p in model.parameters() if p.requires_grad]

    optimizer_classes = {
        "Adadelta": torch.optim.Adadelta,
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
    }
    if args.optim not in optimizer_classes:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError("Unsupported optimizer: %s" % args.optim)
    optimizer = optimizer_classes[args.optim](
        parameters_trainable,
        lr=args.learning_rate,
        weight_decay=args.init_weight_decay)

    log_msg = "优化函数:%s \n 学习率:%s \n 隐藏层:%s\n 保存模型名称:%s \n" % (
        args.optim, args.learning_rate, args.d_model, args.modelName)
    logger.info(log_msg)
    print(log_msg)

    set_epoch = 1
    pbar = tqdm(total=len(train_data) * args.num_epoches // args.batch_size +
                1)

    def make_mask(idxs):
        """Apply ``function`` to every id of a CPU index batch; return a float CUDA tensor."""
        # ``dtype=float`` is what the removed ``np.float`` alias meant
        # (float64). ``np.float`` raises AttributeError on NumPy >= 1.24.
        return torch.Tensor(
            np.array([list(map(function, xx)) for xx in idxs.data.numpy()],
                     dtype=float)).cuda()

    def loss_func(high_out, low_out, seleout11, seleout12, seleout21,
                  seleout22):
        """Return (total loss, ranking term) for one batch."""
        ones = torch.ones(high_out.size(0), 1).cuda()
        ones1 = 7 * torch.ones(high_out.size(0), 1).cuda()
        loss = torch.mean(ones - high_out + low_out) + torch.mean((ones1 - seleout11)*(ones1 - seleout11)) + torch.mean((ones1 - seleout12)*(ones1 - seleout12)) + \
               torch.mean((ones1 - seleout21)*(ones1 - seleout21)) + torch.mean((ones1 - seleout22)*(ones1 - seleout22))
        return F.relu(loss), torch.mean(ones - high_out + low_out)

    print_loss_total = 0
    old_accu = 0
    print_loss_total2 = 0
    # -1 means "no dev improvement yet". Without an initial value, the log
    # line below raised UnboundLocalError when the first dev MAP was not > 0.
    best_epoch = -1
    for train_step, (train_batch, epoch) in enumerate(train_batches):
        pbar.update(1)
        high_context_idxs = train_batch['high_cit_context_idxs']
        high_seg_ids = train_batch['high_seg_indexs']
        low_context_idxs = train_batch['low_cit_context_idxs']
        low_seg_ids = train_batch['low_seg_indexs']
        high_source_context_idxs = train_batch['high_source_context_idxs']
        high_source_seg_indexs = train_batch['high_source_seg_indexs']
        low_source_context_idxs = train_batch['low_source_context_idxs']
        low_source_seg_indexs = train_batch['low_source_seg_indexs']

        # Masks are built from the CPU tensors before they are moved to GPU.
        high_context_mask = make_mask(high_context_idxs)
        low_context_mask = make_mask(low_context_idxs)
        high_source_context_mask = make_mask(high_source_context_idxs)
        low_source_context_mask = make_mask(low_source_context_idxs)

        high_context_idxs = Variable(high_context_idxs).cuda()
        high_seg_ids = Variable(high_seg_ids).cuda()
        low_context_idxs = Variable(low_context_idxs).cuda()
        low_seg_ids = Variable(low_seg_ids).cuda()
        high_source_context_idxs = Variable(high_source_context_idxs).cuda()
        high_source_seg_indexs = Variable(high_source_seg_indexs).cuda()
        low_source_context_idxs = Variable(low_source_context_idxs).cuda()
        low_source_seg_indexs = Variable(low_source_seg_indexs).cuda()

        out1, seleout11, seleout12 = model.forward(high_context_idxs,
                                                   high_seg_ids,
                                                   high_context_mask,
                                                   high_source_context_idxs,
                                                   high_source_seg_indexs,
                                                   high_source_context_mask)
        out2, seleout21, seleout22 = model.forward(low_context_idxs,
                                                   low_seg_ids,
                                                   low_context_mask,
                                                   low_source_context_idxs,
                                                   low_source_seg_indexs,
                                                   low_source_context_mask)
        optimizer.zero_grad()
        loss, loss2 = loss_func(out1, out2, seleout11, seleout12, seleout21,
                                seleout22)
        loss.backward()
        optimizer.step()
        print_loss_total += loss.data.item()
        print_loss_total2 += loss2.data.item()
        del out1, out2
        if train_step % 100 == 0:
            log_msg = 'Epoch: %d, Train_step %d  loss1: %.4f, loss2:%.4f' % (
                epoch, train_step, print_loss_total / 100,
                print_loss_total2 / 100)
            logger.debug(log_msg)
            print(log_msg)
            print_loss_total = 0
            print_loss_total2 = 0

        if epoch == set_epoch:
            set_epoch += 1
            dev_batches = batch.dev_batch(dev_data, args.context_limit)
            result_dic = {}
            true_label_dic = {}
            # Evaluate in inference mode (dropout off, no autograd graph),
            # then switch back to training mode below.
            model.eval()
            with torch.no_grad():
                for dev_batch in dev_batches:
                    context_idxs = dev_batch['context_idxs']
                    source_context_idxs = dev_batch['source_context_idxs']
                    seg_indexs = dev_batch['seg_indexs']
                    source_seg_indexs = dev_batch['source_seg_indexs']
                    ref_labels = dev_batch['ref_labels']
                    sample_id = dev_batch['id']

                    context_mask = make_mask(context_idxs)
                    source_context_mask = make_mask(source_context_idxs)

                    context_idxs = Variable(context_idxs).cuda()
                    seg_indexs = Variable(seg_indexs).cuda()
                    source_context_idxs = Variable(source_context_idxs).cuda()
                    source_seg_indexs = Variable(source_seg_indexs).cuda()
                    out, seleout1, seleout2 = model.forward(
                        context_idxs, seg_indexs, context_mask,
                        source_context_idxs, source_seg_indexs,
                        source_context_mask)
                    if sample_id not in result_dic:
                        result_dic[sample_id] = []
                        true_label_dic[sample_id] = ref_labels
                    result_dic[sample_id].append(out.cpu().data)
                    del out
            model.train()

            picklesave(result_dic, "./modelsave/all_dev_result_dic22.pkl",
                       "./modelsave/result_dic.pkl")
            picklesave(true_label_dic,
                       "./modelsave/all_dev_true_label_dic22.pkl",
                       "./modelsave/true_label_dic.pkl")
            MAPS = 0
            precisions = 0
            recalls = 0
            for key, outs in result_dic.items():
                out = torch.cat(outs, dim=0)
                predict_index = torch.topk(out, 2,
                                           dim=0)[1].squeeze(1).data.numpy()
                precision, recall, MAP = cal_MAP(true_label_dic[key],
                                                 predict_index)
                MAPS += MAP
                precisions += precision
                recalls += recall

            # NOTE(review): averaged over len(dev_data), not len(result_dic).
            # These agree only if every dev sample has its own id. Confirm.
            MAPS /= len(dev_data)
            precisions /= len(dev_data)
            recalls /= len(dev_data)
            all_loss = MAPS
            if all_loss > old_accu:
                old_accu = all_loss
                torch.save(model.state_dict(),
                           "./modelsave/max" + args.modelName)
                best_epoch = epoch
            # The reported epoch is the one with the best dev MAP.
            log_msg = '\n验证集的MAP为: %.4f  P为: %.4f  R为: %.4f\n 取得最小loss的epoch为:%d' % (
                all_loss, precisions, recalls, best_epoch)
            logger.info(log_msg)
            print(log_msg)
            # Save the latest weights after every evaluation.
            torch.save(model.state_dict(), "./modelsave/" + args.modelName)
    torch.save(model.state_dict(), "./modelsave/" + args.modelName)
    pbar.close()
Ejemplo n.º 8
0
def _top_candidate_indices(scores, threshold=0.1, top_k=6):
    """Return the indices of the best-scoring candidates, each at most once.

    Candidates are ranked by score in descending order; ties keep their
    original order because Python's sort is stable.  At most ``top_k``
    ranks are kept, and only those whose score is strictly greater than
    ``threshold``.
    """
    ranked = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
    return [k for k in ranked[:top_k] if scores[k] > threshold]


def _citation_fields(citation, score):
    """Return the processed context fields of one candidate plus its score."""
    return {
        'cit_up_source': data_process(citation['up_source_tokens']),
        'cit_down_source': data_process(citation['down_source_tokens']),
        'cit_target': data_process(citation['target_tokens']),
        # Key name kept for downstream readers; the value is the SVM score.
        'bleu1_score': score,
    }


def _build_small_pairs(data, idf_dic):
    """Return the (high, low) candidate pairs generated for one sample."""
    import random

    target = data_process(data["target_tokens"])
    if len(target) < 50:
        return []
    citations = data["citations_tokens"]

    # Identical or short candidate targets get score 0, so they are never
    # treated as positives.
    scores = []
    for citation in citations:
        cit_target = data_process(citation["target_tokens"])
        if target == cit_target or len(cit_target) < 50:
            scores.append(0)
        else:
            scores.append(getSVMScore(idf_dic, process_kuohao(target),
                                      process_kuohao(cit_target)))

    best_indexs = _top_candidate_indices(scores)
    # If every candidate is a positive, there is nothing to sample a negative
    # from, and the rejection loop below would never terminate.
    if len(best_indexs) == len(citations):
        return []

    up_source = data_process(data["up_source_tokens"])
    down_source = data_process(data["down_source_tokens"])
    pairs = []
    for best_index in best_indexs:
        # Draw a random negative among the non-positive candidates.
        low_index = random.randint(0, len(scores) - 1)
        while low_index in best_indexs:
            low_index = random.randint(0, len(scores) - 1)
        if scores[best_index] == scores[low_index] or scores[
                best_index] == 1.0:
            continue
        low_dic = _citation_fields(citations[low_index], scores[low_index])
        if low_dic['cit_target'] == target:
            continue
        pairs.append({
            'up_source': up_source,
            'down_source': down_source,
            'target': target,
            'high_dic': _citation_fields(citations[best_index],
                                         scores[best_index]),
            'low_dic': low_dic,
        })
    return pairs


def getSmallPairsTrainData():
    """Build (high, low) candidate-citation training pairs, split into chunks.

    For every sample of ``../data2/train_data2.pkl`` whose processed target
    has at least 50 characters, each candidate citation is scored against the
    target with ``getSVMScore`` using the IDF table ``../data2/idf.pkl``.
    Up to the 6 best candidates scoring above 0.1 become "high" examples.
    Each is paired with a random "low" candidate drawn from the rest.  Pairs
    are grouped into chunks that close after sample indices n//5, 2n//5,
    3n//5, 4n//5 and n-1.  The list of chunks is saved to
    ``./train_data/small_pairs_train_data.pkl``.

    Each input sample is expected to have ``up_source_tokens``,
    ``down_source_tokens``, ``target_tokens`` and a ``citations_tokens`` list
    of candidates with the same three fields (TODO confirm schema).
    """
    datas = pickleload("../data2/train_data2.pkl", "./data2/train_data2.pkl")
    idf_dic = pickleload("../data2/idf.pkl", "idf.pkl")
    print(len(datas))

    n = len(datas)
    boundaries = {n // 5, n * 2 // 5, n * 3 // 5, n * 4 // 5, n - 1}
    train_datas = []
    train_spill = []
    for i in tqdm(range(n)):
        train_spill.extend(_build_small_pairs(datas[i], idf_dic))
        # Checked for every sample, including skipped ones. Otherwise a
        # skipped boundary sample merges two chunks or drops the last one.
        if i in boundaries:
            train_datas.append(train_spill)
            print(len(train_spill))
            train_spill = []

    print(len(train_datas))
    print("训练样本的数据量为:", sum(len(chunk) for chunk in train_datas))
    picklesave(train_datas, "./train_data/small_pairs_train_data.pkl",
               "small_pairs_train_data.pkl")
Ejemplo n.º 9
0
def test(args):
    """Evaluate a saved ``Classify`` model on the dev split and print MAP/P/R.

    The dev split is the last fifth of ``../data2/random_train_data.pkl``.
    The checkpoint ``./modelsave/maxclassifyModel2.pkl`` is loaded (the name
    is hard-coded; ``args.loadmodelName`` is not used for loading).  Raw
    predictions and reference labels are pickled to
    ``./modelsave/classifyModel2_predict.pkl`` and
    ``./modelsave/classifyModel2_true.pkl``.  Metrics come from ``cal_MAP``
    on the top-2 predictions per sample id.  CUDA is required.

    Args:
        args: run configuration. ``args.dropout`` is overwritten with 0.0.
    """
    args.dropout = 0.0
    data = pickleload("../data2/random_train_data.pkl", "traindata")
    dev_data = data[len(data) * 4 // 5:len(data)]

    batch = Batch(args)
    word2index = pickleload("./word_vec/word2index.pkl", "word2index.pkl")
    input_vec = len(word2index)

    dev_batches = batch.dev_batch(dev_data, args.context_limit)

    log_msg = "输入词空间大小:%d" % (input_vec)
    logger.info(log_msg)
    print(log_msg)

    transform = Transformer(args, input_vec)
    if torch.cuda.is_available():
        transform = transform.cuda()

    model = Classify(args, transform)
    model_name = "maxclassifyModel2.pkl"
    model.load_state_dict(torch.load("./modelsave/" + model_name))
    if torch.cuda.is_available():
        model = model.cuda()
    # Inference mode: disables dropout/batch-norm updates in submodules.
    model.eval()

    # Log the checkpoint that is actually loaded.
    log_msg = "模型名称:%s \n" % (model_name)
    logger.info(log_msg)
    print(log_msg)

    result_dic = {}
    true_label_dic = {}
    with torch.no_grad():
        for dev_batch in dev_batches:
            context_idxs = dev_batch['context_idxs']
            seg_indexs = dev_batch['seg_indexs']
            ref_labels = dev_batch['ref_labels']
            sample_id = dev_batch['id']
            print(sample_id)
            # ``dtype=float`` replaces the removed ``np.float`` alias
            # (float64), which raises AttributeError on NumPy >= 1.24.
            context_mask = torch.Tensor(
                np.array(
                    [list(map(function, xx))
                     for xx in context_idxs.data.numpy()],
                    dtype=float)).cuda()

            context_idxs = Variable(context_idxs).cuda()
            seg_indexs = Variable(seg_indexs).cuda()
            out = model.forward(context_idxs, seg_indexs, context_mask)
            if sample_id not in result_dic:
                result_dic[sample_id] = []
                true_label_dic[sample_id] = ref_labels
            result_dic[sample_id].append(out.cpu().data)
            del out
    picklesave(result_dic, "./modelsave/classifyModel2_predict.pkl",
               "./modelsave/result_dic.pkl")
    picklesave(true_label_dic, "./modelsave/classifyModel2_true.pkl",
               "./modelsave/true_label_dic.pkl")
    MAPS = 0
    precisions = 0
    recalls = 0
    for key, outs in result_dic.items():
        out = torch.cat(outs, dim=0)
        predict_index = torch.topk(out, 2, dim=0)[1].squeeze(1).data.numpy()
        precision, recall, MAP = cal_MAP(true_label_dic[key], predict_index)
        MAPS += MAP
        precisions += precision
        recalls += recall

    # NOTE(review): averaged over len(dev_data), not len(result_dic). These
    # agree only if every dev sample has its own id. Confirm.
    MAPS /= len(dev_data)
    precisions /= len(dev_data)
    recalls /= len(dev_data)
    print("MAP:%.4f  P:%.4f  R:%.4f" % (MAPS, precisions, recalls))
Ejemplo n.º 10
0
    def forward(self, sentence, sentence_Seg, sentence_mask, source_sentence,
                source_sentence_Seg, source_sentence_mask):
        """Score a (sentence, source sentence) pair.

        Both inputs are encoded with ``self.transform.getTransforEmbedding``.
        Each is trimmed to the largest row-sum of its mask in the batch and
        passed through ``self.dropout``.  The first position of each encoding
        serves as the query of an attention (``self.self_attention``) over
        the other input's remaining positions.  The two query vectors and the
        two attended vectors are concatenated and fed to ``self.linear``.

        As a debugging aid, the attention weights are printed and pickled to
        ``../alpha.pkl`` on every call.  NOTE(review): consider gating this
        in training runs.  ``.type(torch.cuda.LongTensor)`` requires CUDA.
        """
        inputs1 = self.transform.getTransforEmbedding(sentence, sentence_Seg,
                                                      sentence_mask)
        inputs2 = self.transform.getTransforEmbedding(source_sentence,
                                                      source_sentence_Seg,
                                                      source_sentence_mask)

        # Trim trailing columns beyond the largest mask sum in the batch.
        q_lens = torch.sum(sentence_mask, dim=1).type(torch.cuda.LongTensor)
        q_len_max = int(torch.max(q_lens, dim=0)[0].cpu().data.numpy())
        inputs1 = inputs1[:, 0:q_len_max, :]
        sentence_mask = sentence_mask[:, 0:q_len_max]

        d_lens = torch.sum(source_sentence_mask,
                           dim=1).type(torch.cuda.LongTensor)
        d_len_max = int(torch.max(d_lens, dim=0)[0].cpu().data.numpy())
        inputs2 = inputs2[:, 0:d_len_max, :]
        source_sentence_mask = source_sentence_mask[:, 0:d_len_max]
        inputs1 = self.dropout(inputs1)
        inputs2 = self.dropout(inputs2)

        sen_1 = inputs1[:, 0, :]
        # The query for the source->sentence attention is the source's own
        # first position. Taking it from inputs1 would duplicate sen_1.
        sen_2 = inputs2[:, 0, :]
        sen_11, alpha1 = self.self_attention(inputs2[:, 1:, :], sen_1,
                                             source_sentence_mask[:, 1:])
        sen_22, alpha2 = self.self_attention(inputs1[:, 1:, :], sen_2,
                                             sentence_mask[:, 1:])

        dic = {
            "alpha1": alpha1.squeeze(1).cpu().data.numpy(),
            "alpha2": alpha2.cpu().data.numpy(),
        }
        print(dic)
        # picklesave takes (obj, path, name), as at every other call site.
        picklesave(dic, "../alpha.pkl", " alpha")
        convout = torch.cat([sen_1, sen_11, sen_2, sen_22], dim=1)
        return self.linear(convout)