예제 #1
0
def score_batches(coord: Coordinator, pp: ParallelParse):
    """Score queued batches until no unfinished sentences remain.

    Repeatedly takes a batch from ``pp.batches``, scores it with
    ``pp.score_batch`` and advances it with ``pp.advance_batch``.  Sentence
    ids whose state is not yet final are put back on ``pp.needs_scoring``;
    ``pp.num_left`` is reduced by the number of sentences that finished.

    Args:
        coord: Coordinator on which ``request_stop()`` is called once
            ``pp.num_left`` reaches zero.
        pp: Parse object providing the queues (``batches``,
            ``needs_scoring``), the per-sentence ``states``, the
            ``num_left`` counter and the scoring/advancing methods.
    """
    transitions = pp.transition_system
    while True:
        current = pp.batches.get()

        # Compute scores for the batch, then step every state in it.
        scores = pp.score_batch(current)
        pp.advance_batch(current, scores)

        # Collect the sentences that still need further scoring.
        unfinished = []
        for sid in current.ids:
            if not transitions.is_final(pp.states[sid]):
                unfinished.append(sid)

        finished_count = len(current.ids) - len(unfinished)
        pp.num_left -= finished_count
        for sid in unfinished:
            pp.needs_scoring.put(sid)

        # Nothing was re-queued but sentences remain elsewhere: send a
        # None wake-up so the batch-preparing side can emit a smaller
        # batch if it needs to.
        if not unfinished and pp.num_left > 0:
            pp.needs_scoring.put(None)

        # Everything is finished: request a stop, unblock the consumer of
        # needs_scoring with a final None, and leave.
        if pp.num_left == 0:
            coord.request_stop()
            pp.needs_scoring.put(None)
            return
예제 #2
0
    def _run_threads(self):
        """Start the batch-preparing and batch-scoring threads and join them.

        Both threads receive the same freshly created ``Coordinator`` and
        this object as their arguments.  After starting them, the thread
        list is handed to ``coord.join`` (presumably blocking until both
        threads finish; confirm against the Coordinator implementation).
        """
        coord = Coordinator()

        # One worker per entry point, created in the order
        # prepare -> score, then started in that same order.
        workers = [
            Thread(target=entry_point, args=[coord, self])
            for entry_point in (prepare_batches, score_batches)
        ]

        for worker in workers:
            worker.start()
        coord.join(workers)
예제 #3
0
def prepare_batches(coord: Coordinator, pp: ParallelParse):
    """Group sentence ids from ``pp.needs_scoring`` into batches.

    Ids read from ``pp.needs_scoring`` are accumulated; once enough have
    been collected, a ``SentenceBatch`` is built and put on ``pp.batches``.
    A ``None`` item acts as a signal: it ends the loop when
    ``coord.should_stop()`` is true, otherwise it gives the loop a chance
    to enqueue a batch smaller than ``pp.batch_size``.

    Args:
        coord: Coordinator whose ``should_stop()`` is consulted whenever a
            ``None`` item is received.
        pp: Parse object providing the queues, ``sentences``, ``states``,
            ``batch_size``, ``num_left`` and the batch-preparation methods.
    """
    def assemble(ids):
        # Look up the sentences first, then the states, for the given ids.
        batch_sents = []
        for sid in ids:
            batch_sents.append(pp.sentences[sid])
        batch_states = []
        for sid in ids:
            batch_states.append(pp.states[sid])

        # Computed in this order: allowed actions, feed dict, action costs.
        allowed = pp.allowed_batch(batch_states)
        feed = pp.prepare_feed(batch_sents, batch_states)
        costs = pp.action_costs(batch_states, batch_sents, allowed)

        return SentenceBatch(ids=ids,
                             sents=batch_sents,
                             states=batch_states,
                             feed_dict=feed,
                             action_costs_list=costs,
                             allowed_list=allowed)

    pending = []
    while True:
        item = pp.needs_scoring.get()
        if item is not None:
            # Regular sentence id: a batch is due once a full batch, or
            # everything that is still left, has been collected.
            pending.append(item)
            batch_due = len(pending) >= min(pp.batch_size, pp.num_left)
        elif coord.should_stop():
            return
        else:
            # Wake-up signal: emit what has been collected if it already
            # covers every sentence still left.
            # NOTE(review): with pending empty and pp.num_left == 0 this is
            # also true, so an empty batch gets enqueued - confirm intended.
            batch_due = len(pending) >= pp.num_left

        if batch_due:
            pp.batches.put(assemble(pending))
            # Rebind rather than clear: the enqueued batch keeps the old list.
            pending = []