def test_threaded_streamer(self):
    """Check ThreadedStreamer.predict for a single batch and a repeated batch.

    `self.input_batch` is predicted alone and compared with
    `self.single_output`. It is then repeated BATCH_SIZE times, predicted
    in one call and compared with `self.batch_output`.
    """
    streamer = ThreadedStreamer(self.vision_model.batch_prediction, batch_size=8)
    # Tear the workers down even when an assertion fails, so a failing
    # test does not leave background workers running into later tests.
    try:
        single_predict = streamer.predict(self.input_batch)
        assert single_predict == self.single_output

        batch_predict = streamer.predict(self.input_batch * BATCH_SIZE)
        assert batch_predict == self.batch_output
    finally:
        streamer.destroy_workers()
def test_future_api(self):
    """Check the asynchronous submit()/Future API of ThreadedStreamer.

    `self.input_batch` is submitted BATCH_SIZE times without waiting. The
    futures are then resolved in submission order and their results are
    concatenated. The concatenation must equal `self.batch_output`.
    """
    streamer = ThreadedStreamer(self.vision_model.batch_prediction, batch_size=8)
    # Tear the workers down even when an assertion fails, so a failing
    # test does not leave background workers running into later tests.
    try:
        # Submit everything first so the requests can be batched together.
        futures = [streamer.submit(self.input_batch) for _ in range(BATCH_SIZE)]

        # Wait on each future in submission order and flatten the results.
        batch_predict = []
        for future in futures:
            batch_predict.extend(future.result())
        assert batch_predict == self.batch_output
    finally:
        streamer.destroy_workers()