from single_stage_mr.models.feature_extracter import PoolVgg
from single_stage_mr.utils.vgg_features import get_features
import sys
import logging
import os
import pyicc
import json

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                    level=logging.INFO)

if __name__ == "__main__":
    path = sys.argv[1]
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2]
    logging.info("get data set")
    inter_set, intra_set = get_data_set(path=path, resize_length=80)
    # vgg特征提取器
    vgg_feature_selector = PoolVgg().cuda()
    inter_feat_df1, inter_feat_df2, intra_feat_df1, intra_feat_df2 = get_features(
        vgg_model=vgg_feature_selector,
        inter_set=inter_set,
        intra_set=intra_set,
        batchsize=16)
    print(inter_feat_df1.shape)
    print(inter_feat_df2.shape)
    feat_ind1 = pyicc._icc([inter_feat_df1, inter_feat_df2], "icc3", 0.75)
    print(len(feat_ind1))
    feat_ind2 = pyicc._icc([intra_feat_df1, intra_feat_df2], "icc3", 0.75)
    print(len(feat_ind2))
    inds = list(set(feat_ind1).intersection(set(feat_ind2)))
    print(len(inds))
    FP = neg_num - TN
    FN = pos_num - TP
    acc = (TP + TN) / (FP + FN + TP + TN)
    # ppv = TP / (TP + FP + 0.001)
    # npv = TN / (TN + FN + 0.001)
    # logging.info("cutoff:%f" % cutoff)
    # logging.info("tp:%d tn:%d fp:%d fn:%d" % (TP, TN, FP, FN))
    # logging.info('SEN:%f SPE:%f PPV:%f NPV:%f Acc:%f' % (sen, spe, ppv, npv, acc))
    return acc, TP, TN, FP, FN


if __name__ == "__main__":
    path = sys.argv[1]
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2]
    logging.info("get data set")
    train_neg_set, train_pos_set, test_pos_set, test_neg_set = get_data_set(
        path=path, resize_length=80)
    f = open(sys.argv[3])
    d = json.load(f)
    icc_inds = d["feature_index"]
    f.close()
    # 分数据集
    total_train_set = train_neg_set + train_pos_set
    total_test_set = test_neg_set + test_pos_set
    # 负样本的数量
    length = len(train_neg_set)
    logging.info("get data set completed")
    logging.info("get feature and cv")
    # vgg特征提取器
    vgg_feature_selector = PoolVgg().cuda()
    # 混合mr
    get_vgg_feature(vgg_model=vgg_feature_selector,