def computeNDCG_GR(base_path,result_path):
    """Return the mean NDCG of the global-reward ranking over all sessions.

    Baseline sessions from ``parse_baseline(base_path)`` are walked in lockstep
    with the sessions produced by ``session_GR(result_path)``. For each pair the
    session's URL order is kept, and each URL is labelled with the value stored
    at index 0 of the matching baseline ``url_list`` entry (presumably the
    relevance label). The result is scored with ``NDCG.computeNDCG``.

    The walk stops (after printing a message) at the first pair whose session
    ids differ.

    Args:
        base_path: path handed to ``parse_baseline``; each parsed item is
            expected to have ``'session_id'`` and ``'url_list'`` keys.
        result_path: path handed to ``session_GR``; each yielded session is
            indexed as ``session[0]`` (id) and ``session[1]`` (list of items
            whose index 1 is the URL id).

    Returns:
        The mean NDCG over the scored sessions, or 0.0 if none were scored.
    """
    total = 0
    num = 0
    base_list = parse_baseline(base_path)
    # NOTE(review): called with one argument here but with (path, parameter) in
    # compute_total_reward; presumably the second parameter has a default.
    for baseline, session in zip(base_list, session_GR(result_path)):
        if baseline['session_id'] != session[0]:
            print('not find corresponding session')
            break
        # URL id -> label taken from the baseline entry.
        rel_by_url = {int(url[1]): url[0] for url in baseline['url_list']}
        # Keep the session's ranking order but attach the baseline labels.
        reranked = [(rel_by_url[int(item[1])], item[1]) for item in session[1]]
        total += NDCG.computeNDCG(reranked)
        num += 1
    mean = total/num if num else 0.0
    print ('Global Reward:',mean)
    return mean
def compute_total_reward(test_path,result_path,base_path,parameter):
    """Blend global and local rewards per session, re-rank, and return mean NDCG.

    For every line of ``base_path``, one item is drawn from the global-reward
    generator ``session_GR(result_path, parameter)`` and one from the
    local-reward generator ``session_LR(test_path, base_path, parameter)``.

    When both items carry the same session id (index 0), each URL present at
    the same position in both lists gets the score::

        gr * bs + lr * (1 - bs) * (p * s_rel + (1 - p) * s_irel)

    Here ``p`` is ``global_reward[2]``. The score is rounded to 6 decimals, and
    the ``[score, url_id]`` pairs are sorted in descending order; ties are
    broken by the larger URL id. The result goes through ``recover_rel``
    (presumably mapping back to baseline labels) and is scored with
    ``NDCG.computeNDCG``.

    Args:
        test_path: path passed to ``session_LR``.
        result_path: path passed to ``session_GR``.
        base_path: baseline file with one JSON object per line, each having a
            ``'url_list'`` key.
        parameter: mapping with keys ``'bs'``, ``'s_rel'`` and ``'s_irel'``.

    Returns:
        Mean NDCG over all baseline lines, or 0.0 if the file has none.

    Raises:
        StopIteration: if either reward generator runs out before the file.
    """
    gr = session_GR(result_path,parameter)
    lr = session_LR(test_path,base_path,parameter)
    bs = parameter['bs']
    ndcg = 0
    num = 0
    with open(base_path,'r',newline='') as base:
        for line in base:
            baseline = json.loads(line)
            base_url_list = baseline['url_list']
            global_reward = next(gr)
            local_reward = next(lr)
            new_url_list = []
            # NOTE(review): on a session-id mismatch new_url_list stays empty and
            # is still passed to recover_rel; confirm that is intended.
            if int(global_reward[0]) == int(local_reward[0]):
                p_rel = global_reward[2]
                # Same for every URL of the session, so compute it once.
                session_weight = (p_rel*parameter['s_rel'])+((1-p_rel)*parameter['s_irel'])
                for gr_item, lr_item in zip(global_reward[1], local_reward[1]):
                    if int(gr_item[1]) == int(lr_item[1]):
                        score = gr_item[0]*bs + lr_item[0]*(1-bs)*session_weight
                        new_url_list.append([round(score,6),int(lr_item[1])])
                new_url_list.sort(reverse = True)
            final_url_list = recover_rel(base_url_list,new_url_list)
            num+=1
            ndcg+=NDCG.computeNDCG(final_url_list)

    mean = ndcg/num if num else 0.0
    print('Total Reward:',mean)
    return mean
def sample(test_path,base_path):
    """Count sessions where a history-boosted re-ranking beats the baseline.

    Each baseline line of ``base_path`` (JSON with ``'url_list'``, ``'serp'`` and
    ``'session_id'``) is paired with the next session from
    ``parse_test_session(test_path)``. Queries in that session whose ``serp``
    differs from the baseline's contribute a bonus for every URL labelled > 1::

        ((label - 1) / log2(i + 2)) * 0.8 ** (base_serp - serp_now) / 2

    Each URL keeps the maximum bonus it receives. Baseline URLs are then scored
    ``1 / log2(idx + 2) + bonus``, rounded to 2 decimals, and sorted in
    descending order. They are relabelled through ``make_dict``, presumably
    URL id -> baseline label. Their NDCG is compared with the baseline's.

    Prints every session that got worse, then the counters ``a`` (improved) and
    ``b`` (worsened). Stops at the first id mismatch.
    """
    s = parse_test_session(test_path)
    a,b=0,0
    with open(base_path,'r',newline='') as base_file:
        for line in base_file:
            baseline = json.loads(line)
            base_url_list = baseline['url_list']
            base_serp = baseline['serp']
            base_dict = make_dict(base_url_list)
            session = next(s)
            if session['session_id'] != baseline['session_id']:
                print('error,not found corresponding base or session')
                break

            # URL id -> best history bonus seen in the session's other queries.
            bonus = collections.defaultdict(int)
            for query in session['query_list']:
                q_url_list = query['url_list']
                serp_now = query['serp']
                if serp_now == base_serp:
                    continue
                for i in range(len(q_url_list)):
                    if q_url_list[i][0] > 1:
                        score = ((q_url_list[i][0])-1)/math.log(i+2,2)
                        score = (score*math.pow(0.8,(base_serp-serp_now)))/2
                        if score > bonus[q_url_list[i][1]]:
                            bonus[q_url_list[i][1]] = score

            temp_url_list = []
            for idx, url_item in enumerate(base_url_list):
                position_score = 1/math.log(idx+2,2)
                temp_url_list.append( [round(position_score+bonus[url_item[1]],2),url_item[1]] )
            temp_url_list.sort(reverse=True)
            # Replace the ranking score with the baseline label so NDCG is comparable.
            for url in temp_url_list:
                url[0] = base_dict[url[1]]
            new_ndcg = NDCG.computeNDCG(temp_url_list)
            old_ndcg = NDCG.computeNDCG(base_url_list)
            if new_ndcg > old_ndcg:
                a+=1
            elif new_ndcg < old_ndcg:
                print(session)
                print(temp_url_list)
                b+=1
    print('a',a,'b',b)
def test_local_method_proportion(test_path,base_path):
    """Count baseline URLs that were clicked/labelled > 1 in earlier queries.

    Each baseline line of ``base_path`` is paired with the next session from
    ``parse_test_session(test_path)``. URLs labelled > 1 in queries with a
    different ``serp`` are recorded; the last seen label wins. Each baseline
    URL with a recorded label > 1 falls into one of two categories:

    * ``a``: the baseline label is exactly 1. The session is printed.
    * ``b``: the baseline label is > 1.

    The baseline NDCG is added to ``a_ndcg`` / ``b_ndcg`` once per such URL, so
    one session can count several times.

    Prints ``a``, mean a-NDCG, ``b``, mean b-NDCG and the number of lines
    read. A mean is printed as 0 when its category is empty; previously this
    raised ZeroDivisionError. Stops at the first session-id mismatch.
    """
    s = parse_test_session(test_path)
    a,b = 0,0
    b_ndcg = 0
    a_ndcg = 0
    total = 0
    with open(base_path,'r',newline='') as base:
        for line in base:
            total+=1
            baseline = json.loads(line)
            base_url_list = baseline['url_list']
            base_serp = baseline['serp']
            session = next(s)
            test_dict = collections.defaultdict(int)
            if session['session_id'] == baseline['session_id']:
                for query in session['query_list']:
                    q_url_list = query['url_list']
                    serp_now = query['serp']
                    if serp_now != base_serp:
                        for i in range(len(q_url_list)):
                            if q_url_list[i][0] > 1:
                                test_dict[q_url_list[i][1]]=q_url_list[i][0]
                for url in base_url_list:
                    if test_dict[url[1]] > 1 and url[0] == 1:
                        a+=1
                        a_ndcg+=NDCG.computeNDCG(base_url_list)
                        print(session)
                    elif url[0]>1 and test_dict[url[1]] > 1:
                        b+=1
                        b_ndcg+=NDCG.computeNDCG(base_url_list)
            else:
                print('error,not found corresponding base or session')
                break
    # Guard against empty categories, which previously crashed here.
    a_mean = a_ndcg/a if a else 0
    b_mean = b_ndcg/b if b else 0
    print('a',a,a_mean,'b',b,b_mean,'total',total)
def ndcg_distribute(path):
    """Print and plot a 10-bin histogram of baseline NDCG values.

    Each line of ``path`` is a JSON object with a ``'url_list'``, scored with
    ``NDCG.computeNDCG``. The bin labelled ``t`` (0.1, 0.2, ..., 1.0) counts
    values below ``t`` that no earlier bin took, i.e. ``[t - 0.1, t)``. The last
    bin also counts values >= 1.0. With only the strict ``<`` test, perfect
    scores of exactly 1.0 matched no bin and were silently dropped.

    Bars are drawn with width -0.1 and ``align='edge'``, so each bar spans
    ``[t - 0.1, t]``.
    """
    ndcg_list = [[i/10,0] for i in range(1,11)]
    with open(path,'r') as baseline:
        for line in baseline:
            url_list = json.loads(line)['url_list']
            ndcg = NDCG.computeNDCG(url_list)
            for item in ndcg_list:
                if ndcg < item[0]:
                    item[1]+=1
                    break
            else:
                # No bin matched: fold scores at or above the top edge
                # (presumably perfect rankings, possibly 1.0 + float error)
                # into the last bin.
                if ndcg >= ndcg_list[-1][0]:
                    ndcg_list[-1][1]+=1
    print(ndcg_list)
    x = [i[0] for i in ndcg_list]
    y = [i[1] for i in ndcg_list]
    plt.bar(x, y,-0.1,color='y', edgecolor='g', linewidth=1, align='edge')
    plt.show()
def computeNDCG_LR(test_path,base_path):
    """Return the mean NDCG of the local-reward (session history) re-ranking.

    Each baseline line of ``base_path`` (JSON with ``'url_list'``, ``'serp'`` and
    ``'session_id'``) is paired with the next session from
    ``parse_test_session(test_path)``. Queries whose ``serp`` differs from the
    baseline's contribute ``compute_reward(rel, i, 0.6, base_serp - serp_now)``
    for every item labelled > 1. The maximum is kept separately per URL
    (index 1) and per domain (index 2).

    Each baseline URL at position ``idx`` then gets a position score of
    ``1 / log2(idx + 2)``; from position 5 on, this is additionally divided by
    ``idx``. A reward is added to it as follows:

    * ``0.6 * domain_reward`` if only the domain was seen;
    * ``0.1 * url_reward`` if the URL itself was seen;
    * otherwise 0.

    The final score is ``position + reward / 2``. URLs are sorted by it in
    descending order and relabelled through ``make_dict``, presumably URL id ->
    baseline label. The result is scored with ``NDCG.computeNDCG``.

    Returns:
        Mean NDCG over scored sessions, or 0.0 if none were scored. Stops
        (after printing a message) at the first session-id mismatch.
    """
    s = parse_test_session(test_path)
    ndcg = 0
    num = 0
    with open(base_path,'r',newline='') as base_file:
        for line in base_file:
            baseline = json.loads(line)
            base_url_list = baseline['url_list']
            base_serp = baseline['serp']
            base_dict = make_dict(base_url_list)
            session = next(s)
            if session['session_id'] != baseline['session_id']:
                print('error,not found corresponding base or session')
                break

            url_dict = collections.defaultdict(int)
            domain_dict = collections.defaultdict(int)
            for query in session['query_list']:
                serp_now = query['serp']
                if serp_now == base_serp:
                    continue
                for i, q_item in enumerate(query['url_list']):
                    rel,url,domain = q_item[0],q_item[1],q_item[2]
                    if rel > 1:
                        # NOTE(review): URL and domain rewards were computed by two
                        # identical calls; computed once here (assumes compute_reward
                        # is pure). Confirm whether different parameters were intended.
                        score = compute_reward(rel,i,0.6,base_serp-serp_now)
                        if score > url_dict[url]:
                            url_dict[url] = score
                        if score > domain_dict[domain]:
                            domain_dict[domain] = score

            temp_url_list = []
            for idx, url_item in enumerate(base_url_list):
                position_score = 1/math.log(idx+2,2)
                if idx>=5:
                    # Extra damping for lower positions (idx >= 5, so never /0).
                    position_score = position_score/(idx)
                d_reward = domain_dict[url_item[2]]
                u_reward = url_dict[url_item[1]]
                if d_reward > 0 and u_reward == 0:
                    reward = d_reward*0.6
                elif u_reward > 0:
                    reward = u_reward*0.1
                else:
                    reward = 0
                temp_url_list.append( [position_score+reward/2,url_item[1]] )

            temp_url_list.sort(reverse=True)
            # Replace the ranking score with the baseline label so NDCG is comparable.
            for url in temp_url_list:
                url[0] = base_dict[url[1]]
            ndcg+=NDCG.computeNDCG(temp_url_list)
            num+=1
    mean = ndcg/num if num else 0.0
    print('Local Reward:',mean)
    return  mean



# computeNDCG_LR(gl.test_sample_url_domain,gl.baseline_sample_url_domain)
# test_local_method_proportion(gl.test_clean_sample,gl.baseline_clean_sample)