-
Notifications
You must be signed in to change notification settings - Fork 0
/
cone_search.py
161 lines (152 loc) · 6.85 KB
/
cone_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import argparse
import math  # stdlib math, the supported home of these functions
import multiprocessing as mp
import os
import time
from uuid import uuid4

import boto3
import numpy as np
import pandas as pd
from numpy import math  # NOTE(review): numpy's re-export of stdlib math was removed in NumPy 2.0 -- prefer the stdlib import above
def get_alpha(radius, dec):
    """Return the RA half-width (degrees) of the bounding box for a cone search.

    Implements the zones-algorithm "alpha" expansion: a cone of ``radius``
    degrees centered at declination ``dec`` spans a wider range of right
    ascension the closer it sits to a pole. Near a pole the cone covers all
    right ascensions, so 180 degrees is returned.
    """
    # Pole guard: past ~89.9 deg the RA window wraps the whole sphere
    # (and the denominator below would approach zero).
    if abs(dec) + radius > 89.9:
        return 180
    sin_radius = math.sin(math.radians(radius))
    # Cosines at the declination extremes of the cone.
    cos_low = math.cos(math.radians(dec - radius))
    cos_high = math.cos(math.radians(dec + radius))
    denominator = np.sqrt(abs(cos_low * cos_high))
    return math.degrees(abs(math.atan(sin_radius / denominator)))
def cone_search(ra, dec, radius, time_start, time_end, flag,
                aws_profile, aws_region, s3_output_location, local_output_location,
                athena_database, athena_workgroup, query_life=10, wait_time=0.1, single_query=True):
    """Cone search of the ``gPhoton_partitioned`` Athena table around (ra, dec).

    Builds one SQL query over the whole zone range (``single_query=True``) or
    a UNION ALL of per-zone queries, submits it to AWS Athena, polls until it
    succeeds or ~``query_life`` seconds have elapsed (then cancels), downloads
    the result CSV from S3, and re-writes it locally under a fresh UUID name.

    Parameters
    ----------
    ra, dec : search-center right ascension / declination, degrees.
    radius : cone radius, degrees.
    time_start, time_end : half-open bounds applied to the table's ``time``
        column (units are whatever the table stores -- not visible here).
    flag : value matched exactly against the table's ``flag`` column.
    aws_profile, aws_region : boto3 session configuration.
    s3_output_location : ``s3://bucket/prefix`` where Athena writes results.
    local_output_location : local directory for downloaded and final CSVs.
    athena_database, athena_workgroup : Athena execution context.
    query_life : approximate seconds to wait before cancelling the query.
    wait_time : poll interval, seconds.
    single_query : one BETWEEN query vs. one sub-query per zone.

    Returns nothing; results are written to ``local_output_location`` and
    progress is printed to stdout.
    """
    def query_athena(query, query_args):
        # Fill the SQL template positionally (plain str.format, not Athena
        # parameter binding) and submit the query asynchronously.
        query = query.format(*query_args)
        response = athena_client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={
                'Database': athena_database
            },
            ResultConfiguration={
                'OutputLocation': s3_output_location
            },
            WorkGroup=athena_workgroup
        )
        print('Query submitted:\n{}\n'.format(query))
        rsp = athena_client.get_query_execution(QueryExecutionId=response['QueryExecutionId'])
        succeeded_query = True if rsp['QueryExecution']['Status']['State'] == 'SUCCEEDED' else False
        # Approximate elapsed time: sums the requested sleep intervals rather
        # than measuring the wall clock.
        num_sec_query_has_been_running = 0
        # check to see if the query has succeeded
        # NOTE(review): only SUCCEEDED is tested here -- a query that enters
        # FAILED or CANCELLED keeps being polled until query_life expires.
        while not succeeded_query:
            if num_sec_query_has_been_running >= query_life:
                print('QUERY CANCELLED: Query {} has been running for ~{} seconds.'.format(response['QueryExecutionId'],
                                                                                           num_sec_query_has_been_running))
                _ = athena_client.stop_query_execution(QueryExecutionId=response['QueryExecutionId'])
                # The caller treats None as "nothing to download".
                return None
            # Progress message roughly once per minute of waiting.
            if num_sec_query_has_been_running % 60 == 0 and num_sec_query_has_been_running:
                duration = int(num_sec_query_has_been_running/60)
                word = 'minutes' if duration > 1 else 'minute'
                print('...Query has been running for ~{} {}.'.format(duration, word))
            # wait until query has succeeded to start the next query
            # Clamp the final sleep so the accumulated total never overshoots
            # query_life.
            if num_sec_query_has_been_running + wait_time > query_life:
                sleep_time = query_life - num_sec_query_has_been_running
            else:
                sleep_time = wait_time
            time.sleep(sleep_time)
            num_sec_query_has_been_running += sleep_time
            rsp = athena_client.get_query_execution(QueryExecutionId=response['QueryExecutionId'])
            succeeded_query = True if rsp['QueryExecution']['Status']['State'] == 'SUCCEEDED' else False
        # The execution id doubles as the stem of the result CSV's file name.
        return response['QueryExecutionId']
    sess = boto3.Session(profile_name=aws_profile,
                         region_name=aws_region)
    athena_client = sess.client('athena')
    s3_client = sess.client('s3')
    # Split s3://bucket/prefix into the bucket name and the key prefix.
    s3_output_path = s3_output_location.replace('s3://', '').split('/')
    bucket = s3_output_path[0]
    additional_s3_path = s3_output_location.replace('s3://{}/'.format(bucket), '')
    # SQL templates: 'single' covers the full zone range with BETWEEN;
    # 'multiple' targets one zoneID and is chained with UNION ALL per zone.
    # The ({}*cx + {}*cy + {}*cz) > {} term is the dot product of each row's
    # unit vector with the search center's, compared against cos(radius).
    queries = {
        'single': '''
        SELECT *
        FROM gPhoton_partitioned
        WHERE zoneID BETWEEN {} AND {}
        AND dec BETWEEN {} AND {}
        AND ra BETWEEN {} AND {}
        AND ({}*cx + {}*cy + {}*cz) > {}
        AND time >= {} AND time < {}
        AND flag = {};
        ''',
        'multiple': '''
        SELECT *
        FROM gPhoton_partitioned
        WHERE zoneID = {}
        AND dec BETWEEN {} AND {}
        AND ra BETWEEN {} AND {}
        AND ({}*cx + {}*cy + {}*cz) > {}
        AND time >= {} AND time < {}
        AND flag = {};
        '''
    }
    # Unit vector of the search center on the celestial sphere.
    cx = math.cos(math.radians(dec)) * math.cos(math.radians(ra))
    cy = math.cos(math.radians(dec)) * math.sin(math.radians(ra))
    cz = math.sin(math.radians(dec))
    # RA half-width of the bounding box (180 near a pole).
    alpha = get_alpha(radius, dec)
    # Shift RA by a full turn so [ra - alpha, ra + alpha] stays non-negative;
    # the wrap past 360 is covered by the 'conditional' argument set below.
    if (ra - alpha) < 0:
        ra = ra + 360
    # Declination zones are 30 arcsec tall, expressed in degrees.
    zoneHeight = 30.0/3600.0
    min_zoneid = int(np.floor((dec - radius + 90.0) / zoneHeight))
    max_zoneid = int(np.floor((dec + radius + 90.0) / zoneHeight))
    # 'non-conditional': the main RA window. 'conditional': the extra window
    # [0, ra - 360 + alpha] appended only when the search wraps past RA=360.
    query_args_collection = {
        'non-conditional': [
            dec - radius, dec + radius,
            ra - alpha, ra + alpha,
            cx, cy, cz, math.cos(math.radians(radius)),
            time_start, time_end,
            flag
        ],
        'conditional': [
            dec - radius, dec + radius,
            0, ra - 360 + alpha,
            cx, cy, cz, math.cos(math.radians(radius)),
            time_start, time_end,
            flag
        ]
    }
    query_collection = ''
    query_argument_collection = []
    # Assemble the final SQL text plus the flat positional-argument list.
    for zoneid in range(min_zoneid, max_zoneid+1):
        query = queries['single'] if single_query else queries['multiple']
        zone_args = [min_zoneid, max_zoneid] if single_query else [zoneid]
        query_args = zone_args + query_args_collection['non-conditional']
        # RA window wraps past 360: UNION ALL in the [0, ra - 360 + alpha] arm.
        if (ra + alpha) > 360:
            query = query.replace(';', '') + '\n UNION ALL\n' + query
            additional_args = [min_zoneid, max_zoneid] if single_query else [zoneid]
            query_args = query_args + additional_args + query_args_collection['conditional']
        # In per-zone mode every sub-query except the last drops its ';' and
        # is chained with UNION ALL; single mode runs one iteration and breaks.
        temp_query = query.replace(';', '') + '\n UNION ALL\n' if not single_query and zoneid != max_zoneid else query
        query_collection += temp_query
        query_argument_collection.extend(query_args)
        if single_query:
            break
    start_time = time.time()
    execution_id = query_athena(query_collection, query_argument_collection)
    print('Time taken to query: ~{:.4f} seconds'.format(time.time()-start_time))
    # get single CSV or accumulate the CSVs from the different SELECT statements
    dfs = []
    download_paths = []
    # None means the query was cancelled (see query_athena above).
    if execution_id is not None:
        # The result object is expected at <prefix>/<QueryExecutionId>.csv.
        path_to_csv = os.path.join(additional_s3_path, execution_id + '.csv')
        download_path = os.path.join(local_output_location, execution_id + '.csv')
        start_time = time.time()
        s3_client.download_file(bucket,
                                path_to_csv,
                                download_path)
        print('Time taken to download: ~{:.4f} seconds'.format(time.time()-start_time))
        dfs.append(pd.read_csv(download_path, engine='python'))
        download_paths.append(download_path)
    if len(dfs):
        # Concatenate, persist under a fresh UUID file name, then delete the
        # raw per-query download(s).
        df = pd.concat(dfs)
        output_location = os.path.join(local_output_location, str(uuid4()) + '.csv')
        df.to_csv(output_location, index=False)
        print('\nData written to {}\n'.format(output_location))
        for download_path in download_paths:
            os.remove(download_path)
        print(df.head())
    else:
        print('No CSVs were found.')