def distance_compare():
    '''
    Relate pairwise DTW distances in one period to those in the next period.

    Reads two CSV files of float rows (180901-181201 and 181201-190201),
    computes cDTW for every pair of row indices in both files, splits the
    first-period distances into 8 equal-width bands between their minimum and
    maximum, and for each band shows a histogram of the second-period
    distances of the same pairs and prints their mean.

    NOTE(review): row i of both files is presumably the same stock; the code
    pairs rows purely by position -- confirm against how the files are built.

    :return: None (plots and prints only)
    '''
    cluster_dir = 'cluster_data\\180901-181201\\'
    cluster_file = 'z_scored_180901-181201.csv'
    predict_test = 'cluster_data\\181201-190201\\'
    predict_test_file = 'z_scored_181201-190201.csv'

    with open(cluster_dir + cluster_file, 'r') as f:
        c_Data = [[float(t) for t in line] for line in csv.reader(f)]
    # Data of the later (prediction) period.
    with open(predict_test + predict_test_file, 'r') as f:
        all_Data = [[float(t) for t in stock] for stock in csv.reader(f)]

    # Only rows present in both files can be paired; indexing c_Data with the
    # length of all_Data would raise IndexError when c_Data is shorter.
    n_series = min(len(c_Data), len(all_Data))

    dis_1 = []  # pairwise distances in the first period
    dis_2 = []  # distances of the same pairs in the later period
    for i in tqdm(range(n_series)):
        for j in range(i + 1, n_series):
            dis_1.append(cDTW(c_Data[i], c_Data[j]))
            dis_2.append(cDTW(all_Data[i], all_Data[j]))

    if not dis_1:
        # Fewer than two series: there is nothing to compare.
        return

    min_1 = min(dis_1)
    max_1 = max(dis_1)

    n_bands = 8
    delta = (max_1 - min_1) / n_bands
    if delta == 0:
        # All first-period distances are equal: a single band holds them all.
        bands = [(min_1, max_1)]
    else:
        # Edges are computed from the band index rather than by repeated
        # addition, so rounding cannot create an extra band; the last upper
        # edge is pinned to max_1 so the largest distance is always included.
        bands = [(min_1 + k * delta, min_1 + (k + 1) * delta)
                 for k in range(n_bands)]
        bands[-1] = (bands[-1][0], max_1)

    for k, (low, high) in enumerate(bands):
        if k == 0:
            # The first band is closed on the left so the minimum-distance
            # pair is not dropped; later bands are (low, high].
            band_dis = [d2 for d1, d2 in zip(dis_1, dis_2)
                        if low <= d1 <= high]
        else:
            band_dis = [d2 for d1, d2 in zip(dis_1, dis_2)
                        if low < d1 <= high]
        plt.hist(band_dis, bins=80)
        plt.show()
        band_mean = np.mean(band_dis) if band_dis else float('nan')
        print(low, "-", high, " mean: ", band_mean)
def AvgDis(Data):
    '''
    Mean cDTW distance over all unordered pairs of sequences in Data.

    The sum of cDTW(s1, s2) over every pair (s1 before s2 in Data) is divided
    by the number of pairs, n * (n - 1) / 2.

    :param Data: sized iterable of sequences accepted by cDTW
    :return: the mean pairwise distance; 0.0 when Data holds fewer than two
             sequences (no pairs exist, and the division would otherwise
             raise ZeroDivisionError)
    '''
    from itertools import combinations

    data_len = len(Data)
    if data_len < 2:
        return 0.0

    total = 0
    # combinations() yields the pairs in the same order as a nested
    # "i, then i+1.." loop, without slicing a new list per outer iteration.
    for s1, s2 in combinations(Data, 2):
        total += cDTW(s1, s2)
    return total * 2 / (data_len * (data_len - 1))
def refine_ans():
    '''
    For every cluster, plot its closest and its farthest pair of series.

    Reads the series from the week-9 z-scored CSV and the cluster label of
    each row from the membership file (second column of
    clustering_ans_dir + mem_file). For each cluster id in
    range(num_cluster), all pairs of member series are compared with cDTW;
    the pair with the smallest distance and the pair with the largest
    distance are each drawn in their own figure.

    Clusters with fewer than two members have no pair and are skipped.

    :return: None (plots only)
    '''
    # '\\w' yields the same characters the previous '\w' did, without
    # relying on an invalid escape sequence.
    cluster_dir = 'cluster_ans\\for_train\\week9_zscored\\'
    cluster_file = 'wk9_zscored190211-190410.csv'
    with open(cluster_dir + cluster_file, 'r') as f:
        c_Data = [[float(t) for t in line] for line in csv.reader(f)]

    with open(clustering_ans_dir + mem_file, 'r') as f:
        mem_1 = [int(line[1]) for line in csv.reader(f)]

    # 1-based x axis shared by all plots; uses the length of the first row.
    x_idx = list(range(1, len(c_Data[0]) + 1))

    for i in range(num_cluster):
        data_i = [c_Data[j] for j, label in enumerate(mem_1) if label == i]
        if len(data_i) < 2:
            # No pair to show; plotting a placeholder against x_idx would
            # fail with a length mismatch.
            continue

        # Scan all C(n, 2) pairs for the extreme distances. Starting from
        # +/- infinity guarantees both pairs get set even when every
        # distance is 0.
        pair_mindis = None
        mindis = float('inf')
        pair_maxdis = None
        maxdis = float('-inf')
        for t1, seq1 in enumerate(data_i):
            for seq2 in data_i[t1 + 1:]:
                thisdis = cDTW(seq1, seq2)
                if thisdis < mindis:
                    mindis = thisdis
                    pair_mindis = [seq1, seq2]
                if thisdis > maxdis:
                    maxdis = thisdis
                    pair_maxdis = [seq1, seq2]

        plt.figure()
        plt.plot(x_idx, pair_mindis[0])
        plt.plot(x_idx, pair_mindis[1])
        plt.show()

        # Explicit new figure so the farthest pair is never drawn on top of
        # the closest pair when show() does not block.
        plt.figure()
        plt.plot(x_idx, pair_maxdis[0])
        plt.plot(x_idx, pair_maxdis[1])
        plt.show()
def compute_DBI_obv():
    '''
    Davies-Bouldin index for a clustering answer which has been refined.

    Cluster labels come from obvious_mem.csv of the training interval; the
    series that are scored are the rows of the prediction-interval z-scored
    CSV whose row number is mem_index_dict[<first column of the membership
    row>]. For each cluster the spread is AvgDis(members) and the centroid is
    the first value returned by DBA_iteration(members[0], members); the index
    is the mean over clusters of max_j (spread_i + spread_j) / cDTW(c_i, c_j).

    NOTE(review): selected prediction rows are kept in prediction-file order
    while labels are in membership-file order; this only lines up if the
    membership file is sorted by mem_index_dict position -- confirm.
    NOTE(review): spread and centroid are computed once per cluster, which
    assumes AvgDis / DBA_iteration give the same result on repeated calls.

    :return: DBI
    :raises ValueError: if a cluster id in range(num_cluster) has no members
    '''
    train_dir = ('cluster_ans/for_train/' + train_begin_date + '-'
                 + train_end_date + '/')
    mem_file_1 = train_dir + 'obvious_mem.csv'
    data_file_1 = train_dir + 'obvious_stock.csv'
    data_file_predict = ('cluster_data/' + predict_begin_date + '-'
                         + predict_end_date + '/' + 'z_scored_'
                         + predict_begin_date + '-' + predict_end_date
                         + '.csv')

    # Only its row count is used, as a consistency check below.
    with open(data_file_1, 'r') as f:
        all_Data = [[float(t) for t in stock] for stock in csv.reader(f)]

    mem = []
    kept_rows = set()  # row numbers of the prediction file to keep
    with open(mem_file_1, 'r') as f:
        for line in csv.reader(f):
            mem.append(int(line[1]))
            kept_rows.add(mem_index_dict[line[0]])

    # An explicit membership test replaces the former bare "except", which
    # also hid malformed numbers in the selected rows.
    with open(data_file_predict, 'r') as f:
        pre_Data = [[float(t) for t in line]
                    for ii, line in enumerate(csv.reader(f))
                    if ii in kept_rows]
    assert len(pre_Data) == len(all_Data), \
        'selected prediction rows do not match obvious_stock.csv row count'

    # Spread and centroid depend only on the cluster itself, so compute them
    # once per cluster instead of once per (i, j) pair.
    spreads = []
    centroids = []
    for i in tqdm(range(num_cluster)):
        members = [pre_Data[stock]
                   for stock, stock_type in enumerate(mem) if stock_type == i]
        if not members:
            raise ValueError('cluster %d has no members' % i)
        spreads.append(AvgDis(members))
        cent_p, _ = DBA_iteration(members[0], np.array(members))
        centroids.append(cent_p)

    DBI = 0
    for i in range(num_cluster):
        tmp_max = -1
        for j in range(num_cluster):
            if j == i:
                continue
            d_cent = cDTW(centroids[i], centroids[j])
            ratio = (spreads[i] + spreads[j]) / d_cent
            if tmp_max < ratio:
                tmp_max = ratio
        DBI += tmp_max
    return DBI / num_cluster
def compute_DBI(if_predict=True):
    '''
    Davies-Bouldin index of the stored clustering, scored on the
    prediction-interval data.

    Series are the rows of predict_dir + predict_file; the cluster label of
    row k is the second column of row k of clustering_ans_dir + mem_file.
    For each cluster the spread is AvgDis(members); the index is the mean
    over clusters of max_j (spread_i + spread_j) / cDTW(c_i, c_j).

    NOTE(review): spread and centroid are computed once per cluster, which
    assumes AvgDis / DBA_iteration give the same result on repeated calls.

    :param if_predict: when true (the default, matching the previous
        hard-coded flag) each centroid is recomputed as the first value
        returned by DBA_iteration(members[0], members); when false the
        centroids are read from clustering_ans_dir + cent_file, dropping
        the first column of every row.
    :return: DBI
    :raises ValueError: if a cluster id in range(num_cluster) has no members
    '''
    # all_Data holds the prediction-interval series.
    with open(predict_dir + predict_file, 'r') as f:
        all_Data = [[float(t) for t in stock] for stock in csv.reader(f)]

    with open(clustering_ans_dir + mem_file, 'r') as f:
        mem = [int(line[1]) for line in csv.reader(f)]

    # Stored centroids are only needed when they are not recomputed.
    cent = []
    if not if_predict:
        with open(clustering_ans_dir + cent_file, 'r') as f:
            cent = [[float(tt) for tt in c[1:]] for c in csv.reader(f)]

    # Spread and centroid depend only on the cluster itself, so compute them
    # once per cluster instead of once per (i, j) pair.
    spreads = []
    centroids = []
    for i in tqdm(range(num_cluster)):
        members = [all_Data[stock]
                   for stock, stock_type in enumerate(mem) if stock_type == i]
        if not members:
            raise ValueError('cluster %d has no members' % i)
        spreads.append(AvgDis(members))
        if if_predict:
            cent_p, _ = DBA_iteration(members[0], np.array(members))
        else:
            cent_p = cent[i]
        centroids.append(cent_p)

    DBI = 0
    for i in range(num_cluster):
        tmp_max = -1
        for j in range(num_cluster):
            if j == i:
                continue
            d_cent = cDTW(centroids[i], centroids[j])
            ratio = (spreads[i] + spreads[j]) / d_cent
            if tmp_max < ratio:
                tmp_max = ratio
        DBI += tmp_max
    return DBI / num_cluster
def numerical_compare():
    c_Data = []
    stock_idx = []
    cluster_dir = 'cluster_ans\\for_train\week9_zscored\\'
    cluster_file = 'wk9_zscored190211-190410.csv'
    with open(cluster_dir + cluster_file, 'r') as f:
        lines = csv.reader(f)
        for line in lines:
            c_Data.append([float(t) for t in line])
    with open(predict_dir + predict_idx_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            stock_idx.append(line.strip())

    all_Data = []  #这是预测的
    with open(predict_dir + predict_file, 'r') as f:
        lines = csv.reader(f)
        for stock in lines:
            all_Data.append([float(t) for t in stock])
    '''
    cent = []
    with open(clustering_ans_dir+cent_file,'r') as f:
        cents = csv.reader(f)
        for c in cents:
            cent.append([float(tt) for tt in c])
    '''
    mem_1 = []
    with open(clustering_ans_dir + mem_file, 'r') as f:
        mems = csv.reader(f)
        for line in mems:
            mem_1.append(int(line[1]))

    global num_cluster
    x_idx = [i for i in range(1, len(all_Data[0]) + 1)]
    x_idx_c = [i for i in range(1, len(c_Data[0]) + 1)]
    all_dis = []
    all_dis_predict = []
    for i in tqdm(range(len(c_Data))):
        for j in range(i + 1, len(c_Data)):
            all_dis.append(cDTW(c_Data[i], c_Data[j]))
            all_dis_predict.append(cDTW(all_Data[i], all_Data[j]))

# plt.hist(all_dis,bins=50)
    for i in range(num_cluster):
        dis_i = []
        dis_i_predict = []
        cs_idx = []
        # 统计一下子:一对股票在训练数据的DTW距离与在预测数据上的DTW距离的关系
        print(i)
        data_i = []
        c_data_i = []
        plt.hist(all_dis, bins=50, color="#FF0000", alpha=.9)
        for j in range(len(mem_1)):
            if mem_1[j] == i:
                cs_idx.append(j)
                data_i.append(all_Data[j])
                c_data_i.append(c_Data[j])
        for t1 in range(len(c_data_i)):
            for t2 in range(t1 + 1, len(c_data_i)):
                dis_i.append(cDTW(c_data_i[t1], c_data_i[t2]))
                dis_i_predict.append(cDTW(data_i[t1], data_i[t2]))
        plt.hist(dis_i, bins=50, color="#C1F320", alpha=.5)
        plt.show()
        #预测
        plt.hist(all_dis_predict, bins=50, color="#FF0000", alpha=.9)
        plt.hist(dis_i_predict, bins=50, color="#C1F320", alpha=.5)
        plt.show()
        '''