Ejemplo n.º 1
0
        def update_parameters():
            """Apply one pickled batch of gradient values to the shared model.

            Reads the pickled payload from ``request.data``, builds a feed
            dict that maps the first element of each ``grads`` entry to the
            payload item at the same index, runs ``train_op`` with that feed
            and then refreshes ``self.weights``.

            When ``lock_acquired`` is truthy the session run and the weight
            refresh happen under the write lock, which is always released.

            Returns:
                The string ``'completed'``. This is returned even when the
                update failed, as long as the failure budget is not used up.

            Raises:
                Exception: when the value drawn from ``cont`` for the current
                    failure is greater than or equal to ``max_errors``.
            """
            with ng.as_default():
                # SECURITY: pickle.loads executes arbitrary code embedded in
                # the payload. Only expose this handler to trusted callers.
                gradients = pickle.loads(request.data)
                nu_feed = {
                    grad_var[0]: gradients[x]
                    for x, grad_var in enumerate(grads)
                }

                if lock_acquired:
                    lock.acquire_write()
                # Everything after the acquisition sits inside try/finally so
                # the write lock cannot stay held if entering the session
                # context (or anything after it) raises.
                try:
                    with glob_session.as_default():
                        glob_session.run(train_op, feed_dict=nu_feed)
                        self.weights = tensorflow_get_weights(
                            trainable_variables)
                except Exception:
                    # Deliberate best effort: a failed update is counted and
                    # dropped until the budget is exhausted. `except Exception`
                    # (not a bare except) lets KeyboardInterrupt / SystemExit
                    # propagate. The builtin next() is used because the
                    # iterator `.next()` method exists only on Python 2.
                    error_cnt = next(cont)
                    if error_cnt >= max_errors:
                        raise Exception(
                            "Too many failures during training")
                finally:
                    if lock_acquired:
                        lock.release()

            return 'completed'
Ejemplo n.º 2
0
    def start_service(self, metagraph, optimizer):
        """
        Build the training graph and serve it through a blocking Flask app.

        The server is started here rather than in ``__init__`` because this
        method is run in a separate process: when Python forks, the fork
        should come from this thread and not from the master thread.

        Routes registered on the app:
            GET  /            -- returns the fixed string 'Lifeomic'.
            GET  /parameters  -- returns ``pickle.dumps(self.weights)``.
            POST /update      -- unpickles gradient values from the request
                                 body and applies them with ``train_op``.

        Reads ``self.iters`` (failure budget) and ``self.acquire_lock``
        (whether handlers take the read/write lock); sets ``self.app`` and
        ``self.weights``.

        Args:
            metagraph: value passed to ``tf.train.import_meta_graph``.
            optimizer: object whose ``apply_gradients`` builds the train op.

        The call blocks in ``app.run`` and does not return while the server
        is running.
        """
        app = Flask(__name__)
        self.app = app
        max_errors = self.iters
        lock = RWLock()

        server = tf.train.Server.create_local_server()
        ng = tf.Graph()
        with ng.as_default():
            tf.train.import_meta_graph(metagraph)
            loss_variable = tf.get_collection(tf.GraphKeys.LOSSES)[0]
            trainable_variables = tf.trainable_variables()
            grads = tf.gradients(loss_variable, trainable_variables)
            # Pair every gradient tensor with its variable; the gradient
            # tensors double as feed_dict keys in the /update handler.
            grads = list(zip(grads, trainable_variables))
            train_op = optimizer.apply_gradients(grads)
            init = tf.global_variables_initializer()

        glob_session = tf.Session(server.target, graph=ng)
        with ng.as_default():
            with glob_session.as_default():
                glob_session.run(init)
                self.weights = tensorflow_get_weights(trainable_variables)

        # Counts failed updates across all requests (starts at 0).
        cont = itertools.count()
        lock_acquired = self.acquire_lock

        @app.route('/')
        def home():
            return 'Lifeomic'

        @app.route('/parameters', methods=['GET'])
        def get_parameters():
            if lock_acquired:
                lock.acquire_read()
            # try/finally: a failing pickle.dumps must not leave the read
            # lock held, otherwise every later writer would block forever.
            try:
                return pickle.dumps(self.weights)
            finally:
                if lock_acquired:
                    lock.release()

        @app.route('/update', methods=['POST'])
        def update_parameters():
            with ng.as_default():
                # SECURITY: pickle.loads executes arbitrary code embedded in
                # the payload, and the app below listens on 0.0.0.0. Only run
                # this service on a trusted network.
                gradients = pickle.loads(request.data)
                # Feed the client-supplied values in place of the gradient
                # tensors so train_op applies them directly.
                nu_feed = {
                    grad_var[0]: gradients[x]
                    for x, grad_var in enumerate(grads)
                }

                if lock_acquired:
                    lock.acquire_write()
                # Everything after the acquisition sits inside try/finally so
                # the write lock is released on every path.
                try:
                    with glob_session.as_default():
                        glob_session.run(train_op, feed_dict=nu_feed)
                        self.weights = tensorflow_get_weights(trainable_variables)
                except Exception:
                    # Deliberate best effort: failed updates are counted and
                    # dropped until max_errors of them have already occurred.
                    # `except Exception` lets KeyboardInterrupt / SystemExit
                    # through; the builtin next() replaces the Python 2-only
                    # `cont.next()`, which raises AttributeError on Python 3.
                    error_cnt = next(cont)
                    if error_cnt >= max_errors:
                        raise Exception("Too many failures during training")
                finally:
                    if lock_acquired:
                        lock.release()

            return 'completed'

        self.app.run(host='0.0.0.0', use_reloader=False, threaded=True, port=5000)