def update_parameters():
    """Apply one worker-POSTed batch of gradients to the global session.

    Returns the string 'completed' on success (Flask response body).

    NOTE(review): this appears to be a duplicate of the nested Flask handler
    defined inside start_service; it relies on enclosing-scope names
    (ng, grads, glob_session, train_op, lock, lock_acquired, cont,
    max_errors, trainable_variables, self) being in scope -- confirm.
    """
    with ng.as_default():
        # SECURITY: pickle.loads on request data can execute arbitrary code.
        # Only expose this endpoint to trusted workers on a private network,
        # or switch to a safe serialization format.
        gradients = pickle.loads(request.data)
        # Map each gradient tensor (grad_var[0]) to the worker-supplied value.
        nu_feed = {}
        for i, grad_var in enumerate(grads):
            nu_feed[grad_var[0]] = gradients[i]

        if lock_acquired:
            lock.acquire_write()
        with glob_session.as_default():
            try:
                glob_session.run(train_op, feed_dict=nu_feed)
                self.weights = tensorflow_get_weights(trainable_variables)
            except Exception:
                # FIX: was a bare `except:` which also traps SystemExit /
                # KeyboardInterrupt; narrow to Exception.
                # FIX: next(cont) is portable across Python 2/3;
                # cont.next() only exists on Python 2.
                error_cnt = next(cont)
                if error_cnt >= max_errors:
                    raise Exception("Too many failures during training")
            finally:
                # Always release the write lock, even when the step failed.
                if lock_acquired:
                    lock.release()
    return 'completed'
def start_service(self, metagraph, optimizer):
    """Run the parameter-server Flask service (blocking call).

    This may be a bit confusing why the server starts here and not in
    __init__: this method is run in a separate process, and when Python
    calls fork we want to fork from this thread and not the master thread.

    :param metagraph: serialized TensorFlow MetaGraphDef to rebuild the
        training graph from.
    :param optimizer: a tf.train optimizer used to apply worker gradients.
    """
    app = Flask(__name__)
    self.app = app
    max_errors = self.iters
    lock = RWLock()

    # Rebuild the training graph from the metagraph inside a fresh graph,
    # wiring gradients of the first registered loss to the trainables.
    server = tf.train.Server.create_local_server()
    ng = tf.Graph()
    with ng.as_default():
        tf.train.import_meta_graph(metagraph)
        loss_variable = tf.get_collection(tf.GraphKeys.LOSSES)[0]
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(loss_variable, trainable_variables)
        grads = list(zip(grads, trainable_variables))
        train_op = optimizer.apply_gradients(grads)
        init = tf.global_variables_initializer()

    # One long-lived session shared by all request handlers below.
    glob_session = tf.Session(server.target, graph=ng)
    with ng.as_default():
        with glob_session.as_default():
            glob_session.run(init)
            self.weights = tensorflow_get_weights(trainable_variables)

    cont = itertools.count()
    lock_acquired = self.acquire_lock

    @app.route('/')
    def home():
        return 'Lifeomic'

    @app.route('/parameters', methods=['GET'])
    def get_parameters():
        """Serve the current pickled weights under a read lock."""
        if lock_acquired:
            lock.acquire_read()
        try:
            # FIX: release the read lock even if pickling raises; the
            # original leaked the lock on failure (no try/finally).
            vs = pickle.dumps(self.weights)
        finally:
            if lock_acquired:
                lock.release()
        return vs

    @app.route('/update', methods=['POST'])
    def update_parameters():
        """Apply one worker-POSTed gradient batch under a write lock."""
        with ng.as_default():
            # SECURITY: pickle.loads on request data can execute arbitrary
            # code; only expose this endpoint to trusted workers.
            gradients = pickle.loads(request.data)
            nu_feed = {}
            for x, grad_var in enumerate(grads):
                nu_feed[grad_var[0]] = gradients[x]

            if lock_acquired:
                lock.acquire_write()
            with glob_session.as_default():
                try:
                    glob_session.run(train_op, feed_dict=nu_feed)
                    self.weights = tensorflow_get_weights(trainable_variables)
                except Exception:
                    # FIX: was a bare `except:` (also traps SystemExit /
                    # KeyboardInterrupt); narrowed to Exception.
                    # FIX: next(cont) is portable; cont.next() is Py2-only.
                    error_cnt = next(cont)
                    if error_cnt >= max_errors:
                        raise Exception("Too many failures during training")
                finally:
                    if lock_acquired:
                        lock.release()
        return 'completed'

    # Blocking: serves until the process is terminated. Reloader disabled
    # because this already runs in its own forked process.
    self.app.run(host='0.0.0.0', use_reloader=False, threaded=True, port=5000)