forked from jhlau/acceptability_prediction

calc_correlation.py · 74 lines (59 loc) · 1.62 KB
"""
Stdin: N/A
Stdout: a header line, then one metric<TAB>pearson-correlation line per scoring function
Author: Jey Han Lau
Date: Jul 14
"""
import argparse
from scipy.stats.mstats import pearsonr
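#Usage (a hypothetical invocation; file names are placeholders):
#  python calc_correlation.py test.csv ratings.txt
#test_csv is expected to have a header row whose columns from the fourth
#onward name the scoring functions; rating_file has one float rating per
#line, aligned with the data rows of the csv.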
#parser arguments
desc = "Calculates correlation of model scoring functions and gug ratings."
parser = argparse.ArgumentParser(description=desc)
#####################
#positional argument#
#####################
parser.add_argument("test_csv", help="csv file that contains the scoring functions")
parser.add_argument("rating_file", help="file that contains the human ratings")
###################
#optional argument#
###################
args = parser.parse_args()
#parameters
debug = False
###########
#functions#
###########
######
#main#
######
#process the rating file
ratings = []
with open(args.rating_file) as f:
    for line in f:
        ratings.append(float(line.strip()))
if debug:
    print("Ratings", len(ratings), "=", ratings[:10])
#process the test.csv file
metrics = []
probs = []
with open(args.test_csv) as f:
    for line_id, line in enumerate(f):
        data = line.strip().split(",")
        if line_id == 0:
            #header row: metric names start at the fourth column
            metrics = data[3:]
            if debug:
                print("\nmetrics =", metrics)
        else:
            #one score list per metric; empty cells count as 0
            for i, score in enumerate(data[3:]):
                if len(probs) == i:
                    probs.append([])
                if score == "":
                    score = 0
                probs[i].append(float(score))
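#Example test.csv shape (hypothetical values; only the layout is assumed,
#with the first three columns ignored by this script):
#  col1,col2,col3,metric_a,metric_b
#  x,y,z,-12.3,-1.1
#  x,y,z,-15.8,-2.4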
print("METRICS\tCORRELATION")
for i, prob in enumerate(probs):
    if debug:
        print("\nmetric =", metrics[i])
        print("\tprob", len(prob), "=", prob[:5])
    #pearsonr returns (r, p-value); keep only r
    corr = pearsonr(ratings, prob)[0]
    print(metrics[i] + "\t" + str(corr))
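#Expected stdout shape (placeholders, not real results):
#  METRICS	CORRELATION
#  <metric name>	<pearson r>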