forked from fizyr/keras-retinanet
/
app.py
151 lines (111 loc) · 4.39 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import config_accessor as cfg
from flask import Flask, request, Response, jsonify
import traceback
import sys
import core
import misc
import logging
from model_encoder import ResponseEncoder
from models import Content
from services import classify_content, add_to_blacklist, remove_from_blacklist, reset_blacklist
# Module-level logger, registered under the 'celum.app' logger name.
logger = logging.getLogger('celum.app')
# Flask application object; the routes and error handler in this module are registered on it.
app = Flask(__name__)
# Generic fallback error text.
# NOTE(review): not referenced anywhere in the visible part of this file - confirm it is still needed.
default_error_message = 'Server endpoint not responding! Please check your request or try again later.'
class InvalidUsage(Exception):
    """Error raised by request handlers to signal a failure to the client.

    Attributes:
        message: Human-readable description, exposed under the ``'exception'``
            key of :meth:`to_dict`.
        status_code: HTTP status code to answer with (class default: 400).
        payload: Optional mapping (or iterable of key/value pairs) with extra
            fields to merge into the serialised error.
    """

    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        # Forward the message to Exception so that str(exc), exc.args and
        # logged tracebacks carry it instead of being empty.
        super().__init__(message)
        self.message = message
        if status_code is not None:
            # Shadow the class-level default only when explicitly overridden.
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return a fresh dict with the payload fields plus the error message."""
        # dict(...) copies, so the caller's payload object is never mutated.
        rv = dict(self.payload or ())
        rv['exception'] = self.message
        return rv
@app.errorhandler(InvalidUsage)
def handle_invalid_usage(error):
    """Turn a raised InvalidUsage into a JSON response with its status code."""
    json_response = jsonify(error.to_dict())
    # The HTTP status comes from the exception (class default or per-instance override).
    json_response.status_code = error.status_code
    return json_response
def handle_request(content):
    """Classify ``content`` and return the JSON-encoded result.

    Args:
        content: Request content object handed to ``classify_content``.

    Returns:
        The value of ``ResponseEncoder(results).to_json()``.

    Raises:
        InvalidUsage: If classification or encoding raises any ``Exception``;
            the traceback is logged and the original error is chained as the
            cause.
    """
    try:
        classification_results = classify_content(content)
        return ResponseEncoder(classification_results).to_json()
    except Exception as exc:
        # Deliberately not a bare ``except:``: SystemExit / KeyboardInterrupt
        # must propagate instead of being reported to the client as an error.
        logger.error(traceback.format_exc())
        raise InvalidUsage('Error while processing request! Please try again later...') from exc
def parse_post_req_content(insert=False):
    """Build a ``Content`` object from the JSON body of the current request.

    Args:
        insert: Value stored on the resulting object's ``insert`` attribute.

    Returns:
        The ``Content`` instance built from ``request.get_json()``.

    Raises:
        InvalidUsage: If reading the JSON body or constructing ``Content``
            raises any ``Exception``; the traceback is logged and the original
            error is chained as the cause.
    """
    try:
        content = Content(request.get_json())
        content.insert = insert
        return content
    except Exception as exc:
        # Deliberately not a bare ``except:``: SystemExit / KeyboardInterrupt
        # must propagate instead of being reported as a malformed request.
        logger.error(traceback.format_exc())
        raise InvalidUsage('Could not parse request! Please check ids and request format.') from exc
def _process_posted_assets(insert):
    """Parse the POSTed content, enforce the per-request asset limit, classify.

    Shared by the insert and classify POST endpoints, which differ only in
    the ``insert`` flag set on the parsed content.

    Raises:
        InvalidUsage: If the request cannot be parsed, carries more assets
            than the configured maximum, or processing fails.
    """
    content = parse_post_req_content(insert=insert)
    max_requests = cfg.resolve_int(cfg.CLASSIFICATION, cfg.max_assets_per_request)
    if len(content.assets) > max_requests:
        raise InvalidUsage('Exceeded maximum number of assets ({}) per request!'.format(max_requests))
    return handle_request(content)


@app.route('/services/v1/insert', methods=['POST'])
def insert_assets():
    """POST endpoint: process the posted assets with the ``insert`` flag set."""
    return _process_posted_assets(insert=True)


@app.route('/services/v1/classify', methods=['POST'])
def classify_assets():
    """POST endpoint: process the posted assets with the ``insert`` flag unset."""
    return _process_posted_assets(insert=False)
@app.route('/services/v1/classify', methods=['GET'])
@misc.jsonp
def classify():
    """GET endpoint: classify one asset given by the ``id`` and ``url`` query args.

    Raises:
        InvalidUsage: If either query parameter is missing or empty.
    """
    query = request.args
    asset_id = query.get('id')
    asset_url = query.get('url')
    # Both parameters are required; reject the request if either is absent/empty.
    if not (asset_id and asset_url):
        raise InvalidUsage('Missing id or url parameter!')
    return handle_request(misc.classify_get_req_to_content(asset_id, asset_url))
@app.route('/services/v1/blacklist/<string:asset_id>', methods=['DELETE'])
def blacklist_id(asset_id):
    """DELETE endpoint: pass ``asset_id`` to ``add_to_blacklist``.

    Responds 200 when the service call returns 0, otherwise 404.
    """
    outcome = add_to_blacklist(asset_id)
    if outcome == 0:
        return Response(status=200)
    return Response(status=404)
@app.route('/services/v1/blacklist/undo/<string:asset_id>', methods=['GET'])
def undo_blacklist_id(asset_id):
    """GET endpoint: pass ``asset_id`` to ``remove_from_blacklist``.

    Responds 200 when the service call returns 0, otherwise 404.
    """
    outcome = remove_from_blacklist(asset_id)
    if outcome == 0:
        return Response(status=200)
    return Response(status=404)
@app.route('/services/v1/blacklist/init', methods=['GET'])
def init_blacklist():
    """GET endpoint: run ``core.initialize_blacklist()`` and reply with 200."""
    core.initialize_blacklist()
    ok_response = Response(status=200)
    return ok_response
@app.route('/services/v1/blacklist/reset', methods=['DELETE'])
def blacklist_reset():
    """DELETE endpoint: call ``reset_blacklist``.

    Responds 200 when the service call returns 0, otherwise 404.
    """
    outcome = reset_blacklist()
    if outcome == 0:
        return Response(status=200)
    return Response(status=404)
@app.route('/services/v1/shutdown', methods=['GET'])
def shutdown_hook():
    """GET endpoint: trigger a backup, then exit via ``sys.exit()``.

    No response object is returned; the handler ends by raising SystemExit.
    """
    # Backup is triggered first, before the exit is requested.
    core.trigger_backup()
    # NOTE(review): sys.exit() only raises SystemExit in the thread serving
    # this request; whether that actually stops the server presumably depends
    # on how it is run (e.g. threaded mode) - confirm.
    sys.exit()
@app.route('/services/v1/index/init', methods=['GET'])
def init_similarity_index():
    """GET endpoint: run ``core.initialize_elastic_search()`` and reply with 200."""
    # NOTE(review): the function name mentions the similarity index, but the
    # call below is initialize_elastic_search - confirm this is intended.
    core.initialize_elastic_search()
    ok_response = Response(status=200)
    return ok_response
@app.before_first_request
def initialize():
    """Run the ``core`` start-up routines, in a fixed order, before the first request.

    NOTE(review): ``before_first_request`` is deprecated in newer Flask
    releases - confirm the pinned Flask version still provides it.
    """
    # Index and blacklist state.
    core.initialize_similarity_index()
    core.initialize_blacklist()
    core.initialize_elastic_search()

    # Models.
    core.initialize_retinanet()
    core.initialize_extraction_model()

    # Scheduled job registration comes last.
    core.initialize_cron_job()
if __name__ == '__main__':
    # Script entry point: set up logging, then start the Flask server using
    # the settings resolved from the RETINANET_SERVER configuration section.
    core.initialize_logging()
    logger.info('Server app started!')

    server_section = cfg.RETINANET_SERVER
    app.run(
        host=cfg.resolve(server_section, cfg.host),
        port=cfg.resolve_int(server_section, cfg.port),
        debug=cfg.resolve_bool(server_section, cfg.debug),
        threaded=cfg.resolve_bool(server_section, cfg.threaded),
    )