def test_spawned_streamer(self):
    # Spawn 4 GPU worker processes, one per CUDA device
    streamer = Streamer(self.vision_model.batch_prediction, batch_size=8,
                        worker_num=4, cuda_devices=(0, 1, 2, 3))
    single_predict = streamer.predict(self.input_batch)
    assert single_predict == self.single_output
    batch_predict = streamer.predict(self.input_batch * BATCH_SIZE)
    assert batch_predict == self.batch_output
def test_spawned_streamer():
    # Spawn 4 GPU worker processes
    streamer = Streamer(vision_model.batch_prediction, batch_size=8, worker_num=4)
    single_predict = streamer.predict(input_batch)
    assert single_predict == single_output
    batch_predict = streamer.predict(input_batch * BATCH_SIZE)
    assert batch_predict == batch_output
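For context, both tests above assume a toy vision model whose batch_prediction maps a list of inputs to a list of outputs of the same length (the contract Streamer expects), plus matching fixtures. A minimal sketch with hypothetical names (DummyVisionModel, input_batch, single_output, batch_output):

BATCH_SIZE = 8

class DummyVisionModel:
    def batch_prediction(self, batch):
        # Hypothetical stand-in: a real model would run GPU inference here;
        # we just tag each input so outputs are easy to assert on.
        return ["predicted_%s" % item for item in batch]

vision_model = DummyVisionModel()
input_batch = ["image_0"]
single_output = ["predicted_image_0"]
batch_output = single_output * BATCH_SIZE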
def main():
    mp.set_start_method("spawn", force=True)
    batch_size = 64
    model = TextInfillingModel()
    # streamer = ThreadedStreamer(model.predict, batch_size=batch_size, max_latency=0.1)
    streamer = Streamer(ManagedBertModel, batch_size=batch_size, max_latency=0.1,
                        worker_num=4, cuda_devices=(0, 1, 2, 3))
    streamer._wait_for_worker_ready()
    # streamer = RedisStreamer()

    text = "Happy birthday to [MASK]"
    num_epochs = 100
    total_steps = batch_size * num_epochs

    # Baseline: predict one sentence at a time on the raw model
    t_start = time.time()
    for i in tqdm(range(num_epochs)):
        output = model.predict([text])
    t_end = time.time()
    print('model prediction time', t_end - t_start)

    # Baseline: predict a full batch at a time on the raw model
    t_start = time.time()
    for i in tqdm(range(num_epochs)):
        output = model.predict([text] * batch_size)
    t_end = time.time()
    print('[batched] sentences per second', total_steps / (t_end - t_start))

    # Streamed: submit everything up front, then wait on the futures
    t_start = time.time()
    xs = []
    for i in range(total_steps):
        future = streamer.submit([text])
        xs.append(future)
    for future in tqdm(xs):  # collect all Future objects first, then wait for the async results
        output = future.result(timeout=20)
    t_end = time.time()
    print('[streamed] sentences per second', total_steps / (t_end - t_start))
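Note that the benchmark hands Streamer the ManagedBertModel class rather than an instance, so each spawned worker can construct its own copy of the model on its assigned GPU. A sketch of such a wrapper, assuming the library's ManagedModel base class (init_model runs once inside each worker; predict handles one assembled batch):

from service_streamer import ManagedModel

class ManagedBertModel(ManagedModel):
    def init_model(self):
        # Runs inside the spawned worker, so CUDA state is initialized
        # in the child process rather than the parent.
        self.model = TextInfillingModel()

    def predict(self, batch):
        return self.model.predict(batch)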
    topk_idxs = topk(logits, k=None)
    topk_labels = [[[labels[idx[0]], idx[1]] for idx in idxs] for idxs in topk_idxs]
    return topk_labels


@app.route('/batch', methods=['POST'])
def batch():
    # Parallel form lists: sent[i] is labeled over the span [start[i], end[i]]
    sent = request.form.getlist('sent')
    start = [int(i) for i in request.form.getlist('start')]
    end = [int(i) for i in request.form.getlist('end')]
    batch = [[sent[i], start[i], end[i]] for i in range(len(sent))]
    topk_labels = streamer.predict(batch)
    return jsonify(topk_labels)


if __name__ == "__main__":
    streamer = Streamer(Model, batch_size=16, max_latency=0.1, worker_num=8,
                        cuda_devices=(0, 1))
    WSGIServer(("0.0.0.0", 3101), app).serve_forever()
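A hypothetical client for the /batch endpoint: requests encodes list values as repeated form fields, which is exactly what request.form.getlist reads back on the server side. The sentences and span indices here are made up for illustration.

import requests

resp = requests.post(
    "http://localhost:3101/batch",
    data={
        "sent": ["The quick brown fox jumps", "over the lazy dog"],
        "start": ["0", "1"],
        "end": ["2", "3"],
    },
)
print(resp.json())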
    inputs = request.form.getlist("s")
    outputs = model.predict(inputs)
    return jsonify(outputs)


@app.route("/stream", methods=["POST"])
def stream_predict():
    inputs = request.form.getlist("s")
    outputs = streamer.predict(inputs)
    return jsonify(outputs)


if __name__ == "__main__":
    import multiprocessing as mp
    mp.freeze_support()
    mp.set_start_method("spawn", force=True)
    streamer = Streamer(ManagedBertModel, batch_size=64, max_latency=0.1,
                        worker_num=8, cuda_devices=(0, 1, 2, 3))
    # ThreadedStreamer for comparison:
    # model = ManagedBertModel(None)
    # model.init_model()
    # streamer = ThreadedStreamer(model.predict, batch_size=64, max_latency=0.1)
    from gevent.pywsgi import WSGIServer
    WSGIServer(("0.0.0.0", 5005), app).serve_forever()
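Likewise, a hypothetical smoke test for the /stream endpoint: each repeated "s" field becomes one element of the input list, and the streamer is free to batch inputs across concurrent requests.

import requests

resp = requests.post(
    "http://localhost:5005/stream",
    data={"s": ["Happy birthday to [MASK]", "The weather is [MASK] today."]},
)
print(resp.json())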