def score_batches(coord: Coordinator, pp: ParallelParse):
    t_sys = pp.transition_system
    while True:
        batch = pp.batches.get()
        # Score states and move to next states, based on the scoring
        batch_scores = pp.score_batch(batch)
        pp.advance_batch(batch, batch_scores)
        # Put non-finished sentences back in scoring queue
        queue_again = [sent_id for sent_id in batch.ids
                       if not t_sys.is_final(pp.states[sent_id])]
        pp.num_left -= (len(batch.ids) - len(queue_again))
        for sent_id in queue_again:
            pp.needs_scoring.put(sent_id)
        # Wake the batch prepare thread up and give it a chance to enqueue
        # a smaller batch if need be
        if len(queue_again) == 0 and pp.num_left > 0:
            pp.needs_scoring.put(None)
        # Our work here is done. Stop the other thread.
        if pp.num_left == 0:
            coord.request_stop()
            pp.needs_scoring.put(None)
            break
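The `None` item pushed onto `needs_scoring` is a wake-up sentinel rather than a sentence id: it nudges the batch-preparation thread so it can either flush a smaller-than-usual batch or notice that a stop has been requested. Stripped of the parsing details, the handshake looks roughly like this (a toy sketch built on `queue.Queue`, not code from the parser itself):

import queue
import threading

work = queue.Queue()
stop = threading.Event()

def consumer():
    pending = []
    while True:
        item = work.get()
        if item is None:
            # Wake-up sentinel: flush whatever has accumulated, then
            # check whether the producer asked us to stop.
            if pending:
                print("flushing partial batch:", pending)
                pending = []
            if stop.is_set():
                break
        else:
            pending.append(item)

t = threading.Thread(target=consumer)
t.start()

for i in range(3):
    work.put(i)
work.put(None)   # wake the consumer so it can flush a partial batch

stop.set()
work.put(None)   # wake it again so it sees the stop request and exits
t.join()

The second `None` matters because the consumer blocks on `get()` and would otherwise never re-check the stop flag; `score_batches` does the same thing right after calling `request_stop`.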
def _run_threads(self):
    # Setup threads
    coord = Coordinator()
    t_prepare_batch = Thread(target=prepare_batches, args=[coord, self])
    t_score_batch = Thread(target=score_batches, args=[coord, self])
    threads = [t_prepare_batch, t_score_batch]
    for t in threads:
        t.start()
    coord.join(threads)
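These snippets assume a `Coordinator` with `request_stop`, `should_stop`, and `join` methods, in the spirit of TensorFlow's `tf.train.Coordinator`. If that class is not available to you, a minimal stand-in (a sketch, not the original class) can be built on a `threading.Event`:

import threading

class Coordinator:
    """Minimal thread coordinator (a sketch; the real class may differ)."""

    def __init__(self):
        self._stop_event = threading.Event()

    def request_stop(self):
        # Signal every coordinated thread that it should exit.
        self._stop_event.set()

    def should_stop(self):
        # Threads poll this inside their loops to know when to exit.
        return self._stop_event.is_set()

    def join(self, threads):
        # Block until all coordinated threads have finished.
        for t in threads:
            t.join()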
def prepare_batches(coord: Coordinator, pp: ParallelParse):
    def build_batch(sent_ids):
        # Gather the sentences and their current states, then bundle
        # everything the scorer needs (model feed, allowed actions, costs)
        sents = [pp.sentences[sent_id] for sent_id in sent_ids]
        states = [pp.states[sent_id] for sent_id in sent_ids]
        allowed_list = pp.allowed_batch(states)
        return SentenceBatch(ids=sent_ids,
                             sents=sents,
                             states=states,
                             feed_dict=pp.prepare_feed(sents, states),
                             action_costs_list=pp.action_costs(states, sents, allowed_list),
                             allowed_list=allowed_list)

    sent_ids = []
    while True:
        item = pp.needs_scoring.get()
        if item is None:
            # Wake-up sentinel from the scoring thread
            if coord.should_stop():
                break
            else:
                # Flush a smaller batch if it already covers every
                # sentence that still needs scoring
                if len(sent_ids) >= pp.num_left:
                    pp.batches.put(build_batch(sent_ids))
                    sent_ids = []
        else:
            sent_ids.append(item)
            # Enqueue a full batch, or everything that is left
            if len(sent_ids) >= min(pp.batch_size, pp.num_left):
                pp.batches.put(build_batch(sent_ids))
                sent_ids = []
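For completeness, here is one way the pieces the two threads communicate through could be wired up. The field names mirror the snippets above, but the exact `SentenceBatch` definition and the `init_parallel_state` helper are illustrative assumptions, not the actual implementation:

import queue
from typing import Any, Dict, List, NamedTuple

class SentenceBatch(NamedTuple):
    # One unit of work handed from prepare_batches to score_batches.
    ids: List[int]                # sentence ids in this batch
    sents: List[Any]              # the sentences themselves
    states: List[Any]             # current parser state per sentence
    feed_dict: Dict[Any, Any]     # model inputs built by prepare_feed
    action_costs_list: List[Any]  # per-state action costs
    allowed_list: List[Any]       # per-state allowed actions

def init_parallel_state(pp, sentences, initial_states):
    # Hypothetical helper (not from the original code): sets up the
    # fields that prepare_batches and score_batches read and write.
    pp.sentences = list(sentences)
    pp.states = list(initial_states)   # one parser state per sentence
    pp.needs_scoring = queue.Queue()   # sentence ids waiting to be batched
    pp.batches = queue.Queue()         # prepared batches waiting to be scored
    pp.num_left = len(pp.sentences)    # sentences not yet in a final state
    for sent_id in range(len(pp.sentences)):
        pp.needs_scoring.put(sent_id)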