示例#1
0
    def decode(self, src_sentence, trgt_sentence):
        """Force-decode ``trgt_sentence`` for ``src_sentence``.

        ``utils.EOS_ID`` is appended to the target and the result is stored
        on ``self.trgt_sentence``. A single hypothesis is then expanded with
        ``self._expand_hypo`` until its last word is ``utils.EOS_ID``.

        NOTE(review): there is no ``max_len`` guard here; termination relies
        on ``_expand_hypo`` eventually producing EOS (presumably by
        following ``self.trgt_sentence``) -- confirm.

        Args:
            src_sentence: Source sentence passed to ``initialize_predictor``.
            trgt_sentence: Target token ids; must support ``+`` with a list
                (presumably a list itself).

        Returns:
            The result of ``self.get_full_hypos_sorted()``.
        """
        self.trgt_sentence = trgt_sentence + [utils.EOS_ID]
        self.initialize_predictor(src_sentence)

        forced_hypo = PartialHypothesis(self.get_predictor_states(),
                                        self.calculate_stats)
        # The return value of _expand_hypo is ignored: it is called for its
        # effect on forced_hypo (presumably extends it in place -- confirm).
        while forced_hypo.get_last_word() != utils.EOS_ID:
            self._expand_hypo(forced_hypo)

        forced_hypo.score = self.get_adjusted_score(forced_hypo)
        full_hypo = forced_hypo.generate_full_hypothesis()
        self.add_full_hypo(full_hypo)
        return self.get_full_hypos_sorted()
示例#2
0
    def decode(self, src_sentence, seed=0):
        """Decode ``self.nbest`` independent hypotheses for ``src_sentence``.

        Every hypothesis starts from its own deep copy of the predictor
        states. At each step (at most ``self.max_len`` steps), hypotheses
        ending in ``utils.EOS_ID`` are scored and stored as full hypotheses.
        All other hypotheses are expanded with ``self._expand_hypo``, passing
        ``seed`` plus their position in the current list of open hypotheses.
        Hypotheses still open after the step limit are stored as they are.

        NOTE(review): the seed offset is the index in the *shrinking* list of
        open hypotheses, so a given hypothesis can receive a different offset
        once earlier ones finish -- confirm this is intended.

        Args:
            src_sentence: Source sentence passed to ``initialize_predictor``.
            seed: Base seed; the per-hypothesis offset is added to it.

        Returns:
            The result of ``self.get_full_hypos_sorted()``.
        """
        self.initialize_predictor(src_sentence)
        live = [PartialHypothesis(copy.deepcopy(self.get_predictor_states()))
                for _ in range(self.nbest)]

        def finalize(h):
            # Score with the decoder's adjustment and record as complete.
            h.score = self.get_adjusted_score(h)
            self.add_full_hypo(h.generate_full_hypothesis())

        step = 0
        while live and step < self.max_len:
            still_open = []
            for offset, h in enumerate(live):
                if h.get_last_word() != utils.EOS_ID:
                    self._expand_hypo(h, seed=seed + offset)
                    still_open.append(h)
                else:
                    finalize(h)
            live = still_open
            step += 1

        # Hypotheses that hit the step limit are finalised unfinished.
        for h in live:
            finalize(h)

        return self.get_full_hypos_sorted()
示例#3
0
    def decode(self, src_sentence):
        """Decode a single source sentence using A* search.

        Partial hypotheses are kept in an open set keyed by
        ``self.get_adjusted_score``:
        - When ``self.capacity > 0``, the open set is a ``MinMaxHeap`` with
          that much reserved space.
        - Otherwise it is a plain list handled by ``self.push``/``self.pop``.

        Hypotheses ending in ``utils.EOS_ID`` are stored as full hypotheses.
        The search returns early once ``self.nbest`` of them have been
        collected. Hypotheses of length ``self.max_len`` or more are
        discarded without expansion.

        If the open set runs dry without any full hypothesis, the empty
        hypothesis from ``self.get_empty_hypo()`` is used as a fallback. This
        only happens when ``self.use_lower_bound`` is enabled; otherwise the
        (possibly empty) sorted result is returned.

        Args:
            src_sentence: Source sentence passed to ``initialize_predictor``.

        Returns:
            The result of ``self.get_full_hypos_sorted()``.
        """
        self.initialize_predictor(src_sentence)
        # None when the lower bound is disabled; checked before use below.
        self.lower_bound = (self.get_empty_hypo()
                            if self.use_lower_bound else None)

        self.cur_capacity = self.capacity
        open_set = (MinMaxHeap(reserve=self.capacity)
                    if self.capacity > 0 else [])
        self.push(
            open_set, 0.0,
            PartialHypothesis(self.get_predictor_states(),
                              self.calculate_stats))

        while open_set:
            _, hypo = self.pop(open_set)
            if hypo.get_last_word() == utils.EOS_ID:  # Found a complete hypo
                hypo.score = self.get_adjusted_score(hypo)
                self.add_full_hypo(hypo.generate_full_hypothesis())
                if len(self.full_hypos) == self.nbest:  # enough hypos
                    return self.get_full_hypos_sorted()
                self.cur_capacity -= 1
                continue

            # '>=' rather than '==' so an over-long hypothesis can never slip
            # through and be expanded further.
            if len(hypo) >= self.max_len:
                continue

            for next_hypo in self._expand_hypo(hypo, self.capacity):
                score = self.get_adjusted_score(next_hypo)
                self.push(open_set, score, next_hypo)

        # Fall back to the lower-bound hypothesis only if one was computed;
        # with use_lower_bound disabled self.lower_bound is None, and
        # dereferencing it would raise AttributeError.
        if not self.full_hypos and self.lower_bound is not None:
            self.add_full_hypo(self.lower_bound.generate_full_hypothesis())
        return self.get_full_hypos_sorted()
示例#4
0
文件: greedy.py 项目: rycolab/bfbs
    def decode(self, src_sentence):
        """Greedily decode ``src_sentence`` into a single hypothesis.

        At each step ``self.apply_predictor`` is called with
        ``(hypothesis, 1)`` in Gumbel mode and ``(None, 1)`` otherwise, and
        its first returned id is taken as the next word.
        - In Gumbel mode, ``original_posterior[0]`` is added to
          ``base_score``.
        - Otherwise, ``posterior[0]`` is added to ``score``.

        In both modes the increment is appended to ``score_breakdown``, and
        the word is appended to ``trgt_sentence`` and passed to
        ``self.consume``. Decoding stops at ``utils.EOS_ID`` or when the
        hypothesis reaches ``self.max_len``.

        Args:
            src_sentence: Source sentence passed to ``initialize_predictor``.

        Returns:
            ``self.full_hypos`` after the single greedy hypothesis is added.
        """
        self.initialize_predictor(src_sentence)
        hypo = PartialHypothesis(self.get_predictor_states())
        while (hypo.get_last_word() != utils.EOS_ID
               and len(hypo) < self.max_len):
            ids, posterior, original_posterior = self.apply_predictor(
                hypo if self.gumbel else None, 1)
            best_word = ids[0]
            if self.gumbel:
                increment = original_posterior[0]
                hypo.base_score += increment
            else:
                increment = posterior[0]
                hypo.score += increment
            hypo.score_breakdown.append(increment)
            hypo.trgt_sentence.append(best_word)
            self.consume(best_word)

        self.add_full_hypo(hypo.generate_full_hypothesis())
        return self.full_hypos
示例#5
0
    def initialize_order_ds(self):
        """Set up the per-length queues and ordering structures for search.

        Effects on ``self``:
        - ``self.queues``: ``self.max_len + 1`` empty ``MinMaxHeap`` objects.
        - ``self.queue_order``: a ``PointerQueue`` seeded with ``[0.0]`` and
          reserving ``self.max_len``.
        - ``self.time_sync``: a ``defaultdict`` whose missing entries default
          to ``self.beam`` (if positive) or ``utils.INF``.

        The starting (BOS) hypothesis is then inserted into queue 0 with
        cost 0.0, and ``time_sync[0]`` is set to 1.
        Presumably ``time_sync[t]`` bounds how many hypotheses are processed
        at step ``t`` -- confirm against callers.
        """
        self.queues = [MinMaxHeap() for _ in range(self.max_len + 1)]
        self.queue_order = PointerQueue([0.0], reserve=self.max_len)

        def _default_time_sync():
            # Read self.beam at lookup time, not at construction time.
            return self.beam if self.beam > 0 else utils.INF

        self.time_sync = defaultdict(_default_time_sync)

        bos_hypo = PartialHypothesis(self.get_predictor_states())
        self.queues[0].insert((0.0, bos_hypo))
        self.time_sync[0] = 1
示例#6
0
 def _get_initial_hypos(self):
     """Split the first expansion of the start hypothesis into groups.

     The start hypothesis is expanded into up to ``self.beam_size``
     candidates via ``self._expand_hypo``. The candidates are cut into
     consecutive slices whose lengths are given by ``self.group_sizes``.
     If fewer candidates are returned than the group sizes sum to, the
     trailing slices are shorter or empty.
     """
     start_hypo = PartialHypothesis(self.get_predictor_states())
     candidates = self._expand_hypo(start_hypo, self.beam_size)
     # Boundaries [0, s0, s0+s1, ...]; each adjacent pair delimits a group.
     bounds = [0] + list(np.cumsum(self.group_sizes))
     return [candidates[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
示例#7
0
 def _get_initial_hypos(self):
     """Return a one-element list holding the start ``PartialHypothesis``.

     The hypothesis is built from the current predictor states.
     """
     initial_states = self.get_predictor_states()
     start_hypo = PartialHypothesis(initial_states)
     return [start_hypo]
示例#8
0
 def _get_initial_hypos(self):
     """Return one starting group per ``self.num_groups``.

     Each group is a one-element list containing a fresh
     ``PartialHypothesis``. Each hypothesis is built from its own deep copy
     of the predictor states and from ``self.calculate_stats``, so groups do
     not share the same states object.
     """
     groups = []
     for _ in range(self.num_groups):
         own_states = copy.deepcopy(self.get_predictor_states())
         groups.append([PartialHypothesis(own_states, self.calculate_stats)])
     return groups