forked from mshakya/GTA_Hunter
/
GTA_Hunter.py
195 lines (176 loc) · 6.26 KB
/
GTA_Hunter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
File name: GTA_Hunter.py
Date created: 10/22/2015
Date last modified: 05/31/2016
Python version: 3.5.1
Description: Combines Loader.py,
Profile.py, SVM.py, Feature.py,
Weight.py, and Filter.py to
characterize and classify two different
types of genes (GTA and phage).
"""
###############
### IMPORTS ###
###############
from Loader import Loader
from Weight import Weight
from Feature import Feature
from SVM import SVM
import argparse
import numpy as np
import time
############
### META ###
############
__author__ = "Taylor Neely"
__version__ = "1.0.0"
__email__ = "tneely@dartmouth.edu"
#################
### CONSTANTS ###
#################
NREPS = 10 # for xval in SVM
PSE_WEIGHT = 0.05 # for pseaac feature
############
### CODE ###
############
def get_args():
    """Build the command-line parser for the gene classifier.

    The options are declared in a table and registered in order, so the
    help listing keeps the same layout: main options first, then the
    weighting options, then the SVM options.

    Returns:
        argparse.ArgumentParser: the configured parser. The caller is
        responsible for invoking ``parse_args()`` on it.
    """
    # Each entry is (option flags, keyword arguments for add_argument).
    # The table is rebuilt on every call so list defaults are never shared.
    option_table = [
        ### Main ###
        (("-g", "--GTA"), dict(
            type=str, nargs=1, dest="gta", required=True,
            help="The .faa or .fna training file for GTA genes.")),
        (("-v", "--virus"), dict(
            type=str, nargs=1, dest="virus", required=True,
            help="The .faa or .fna training file for viral genes.")),
        (("-q", "--queries"), dict(
            type=str, nargs=1, dest="queries", required=False,
            help="The .faa or .fna query file to be classified.")),
        (("-k", "--kmer"), dict(
            type=int, nargs="?", dest="kmer", required=False,
            const=4, default=None,
            help="The kmer size needed for feature generation (default=4).")),
        (("-p", "--pseaac"), dict(
            nargs="?", type=int, dest="pseaac", required=False,
            const=3, default=None,
            help="Expand feature set to include pseudo amino acid composition. Specify lamba (default=3). Weight = 0.05.")),
        (("-y", "--physico"), dict(
            action="store_true", dest="physico", required=False,
            help="Expand feature set to include physicochemical composition.")),
        (("-m", "--min"), dict(
            action="store_true", dest="mini", required=False,
            help="Print bare minimum results.")),
        ### Weight ###
        (("-w", "--weight"), dict(
            type=str, nargs=2, dest="weight", required=False,
            help="Allows for weighting of training set. Will need to specify the two pairwise distance files needed for weighting (GTA first, then virus).")),
        (("-t", "--cluster_type"), dict(
            type=str, nargs=1, dest="cluster_type", required=False,
            default=['farthest'],
            help="Specify 'farthest' or 'nearest' neighbors clustering (default='farthest').")),
        (("-d", "--dist"), dict(
            type=float, nargs=1, dest="dist", required=False,
            default=[0.01],
            help="Specify the cutoff distance for clustering in the weighting scheme (default=0.01).")),
        ### SVM ###
        (("-c", "--soft_margin"), dict(
            type=float, nargs=1, dest="c", required=False,
            default=[1.0],
            help="The soft margin for the SVM (default=1.0).")),
        (("-x", "--xval"), dict(
            type=int, nargs="?", dest="xval", required=False,
            const=5, default=None,
            help="Performs cross validation of training set. Specify folds over 10 repetitions (default=5).")),
        (("-e", "--kernel"), dict(
            nargs=2, dest="kernel", required=False,
            default=["linear", 0],
            help="Specify kernel to be used and sigma if applicable (i.e. gaussian) (default='linear', 0).")),
        (("-s", "--svs"), dict(
            action="store_true", dest="svs", required=False,
            help="Show support vectors.")),
    ]

    cli = argparse.ArgumentParser(description="Gene Classification Using SVM.")
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli
if __name__ == '__main__':
    # Script entry point: load the GTA and viral training sets, build the
    # requested features, optionally weight the training profiles, create
    # the SVM, then either cross-validate the training set (-x) or classify
    # the query file (-q). The elapsed wall-clock time is printed last.
    start = time.time()

    # Get args
    parser = get_args()
    args = parser.parse_args()

    # Print detail
    mini = args.mini

    # -k and -p default to None when omitted and carry an int when given.
    # Decide once which feature types are active, so the training set, the
    # "no feature selected" guard, and the query set all agree.
    use_kmer = args.kmer is not None
    use_pseaac = args.pseaac is not None

    ### Load training set and make features ###
    gta_file = args.gta[0]
    virus_file = args.virus[0]

    # Load profiles
    gta_profs = Loader.load(gta_file, "GTA")
    viral_profs = Loader.load(virus_file, "virus")

    # Make features (only for the feature types that were requested)
    feats = Feature(gta_profs.profiles + viral_profs.profiles)
    if use_kmer:
        kmer_size = args.kmer
        feats.make_kmer_dict(kmer_size)
        feats.kmer_feat()
    if use_pseaac:
        feats.pseaac(lam=int(args.pseaac), weight=PSE_WEIGHT)
    if args.physico:
        feats.physicochem()

    if not (use_kmer or use_pseaac or args.physico):
        print("You must specify at least one feature type (-k, -p, -y).")
    else:
        # Weight if needed
        if args.weight:
            # Get distance threshold
            d = args.dist[0]
            # Get cluster type
            cluster_type = args.cluster_type[0]

            # Weight GTA
            pairwiseGTA = Weight.load(args.weight[0])
            GTA_weight = Weight(gta_profs, pairwiseGTA)
            GTA_clusters = GTA_weight.cluster(cluster_type, d)
            GTA_weight.weight(GTA_clusters)

            # Weight Virus
            pairwiseViral = Weight.load(args.weight[1])
            virus_weight = Weight(viral_profs, pairwiseViral)
            virus_clusters = virus_weight.cluster(cluster_type, d)
            virus_weight.weight(virus_clusters)

        # Create SVM
        c = args.c[0]
        kernel = args.kernel[0]
        # The kernel parameter arrives as a string from the command line
        # (or as the int 0 default), so normalise it to float.
        kernel_var = float(args.kernel[1])
        svm = SVM(gta_profs, viral_profs, c, kernel, kernel_var)

        # Print support vectors
        if args.svs:
            svm.show_svs()

        # Xval
        if args.xval:
            nfolds = args.xval
            if args.weight:
                # Pass the distance data along with the clustering settings
                result = svm.xval(nfolds, NREPS, pairwiseGTA, pairwiseViral, cluster_type, d)
            else:
                result = svm.xval(nfolds, NREPS)
            if mini:
                print("GTA Correct\tViral Correct")
                print("%.2f\t%.2f" % (result[0], result[1]))
            else:
                print("We correctly classified (on average) %.2f/%d GTA and %.2f/%d Viral genes."
                      % (result[0], len(gta_profs), result[1], len(viral_profs)))
        else:  # Otherwise classify test set
            # Make sure queries set
            if args.queries is None:
                print("The query file was not specified. Please declare queries using -q.")
            else:  # All good
                # Load test set
                test_file = args.queries[0]
                test_profs = Loader.load(test_file)

                # Make features: must mirror the feature types built for
                # the training set above.
                if use_kmer:
                    feats.kmer_feat(test_profs)
                if use_pseaac:
                    feats.pseaac(lam=int(args.pseaac), weight=PSE_WEIGHT, profiles=test_profs)
                if args.physico:
                    feats.physicochem(profiles=test_profs)

                # Classify
                svm.predict(test_profs)

                # Print results
                if mini:
                    print("Gene\t\tClass")
                    for profile in test_profs:
                        print(">%s\t%s" % (profile.name, profile.label))
                else:
                    print("%-*s%-*s%-*s" % (55, "Gene", 10, "Score", 5, "Classification"))
                    for profile in test_profs:
                        print(">%-*s%-*f%-*s" % (55, profile.org_name, 10, profile.score, 5, profile.label))

    # Report elapsed wall-clock time in seconds
    end = time.time()
    print(end - start)