コード例 #1
0
class GreedyHeuristic(Heuristic):
    """This heuristic performs greedy decoding to get future cost
    estimates. This is expensive but can lead to very close estimates.
    """

    def __init__(self, decoder_args, cache_estimates=True):
        """Creates a new ``GreedyHeuristic`` instance. The greedy
        heuristic performs full greedy decoding from the current
        state to get accurate cost estimates. However, this can be very
        expensive.

        Args:
            decoder_args (object): Decoder configuration passed through
                                   from the configuration API.
            cache_estimates (bool): Set to true to enable a cache for
                                    predictor states which have been
                                    visited during the greedy decoding.
        """
        super(GreedyHeuristic, self).__init__()
        self.cache_estimates = cache_estimates
        self.decoder = GreedyDecoder(decoder_args)
        self.cache = SimpleTrie()

    def set_predictors(self, predictors):
        """Override ``Heuristic.set_predictors`` to also redirect the
        predictors to the internal greedy decoder ``self.decoder``.
        """
        self.predictors = predictors
        self.decoder.predictors = predictors

    def initialize(self, src_sentence):
        """Clears the cache of future cost estimates.

        Args:
            src_sentence (list): Not used
        """
        self.cache = SimpleTrie()

    def estimate_future_cost(self, hypo):
        """Estimate the future cost by full greedy decoding. If
        ``self.cache_estimates`` is enabled, check the cache first.

        Args:
            hypo (PartialHypothesis): Hypothesis whose last word in
                ``trgt_sentence`` has not been consumed yet by the
                predictors.

        Returns:
            float. Negated sum of the greedy continuation scores.
        """
        if self.cache_estimates:
            return self.estimate_future_cost_with_cache(hypo)
        return self.estimate_future_cost_without_cache(hypo)

    def estimate_future_cost_with_cache(self, hypo):
        """Cached variant of ``estimate_future_cost``. The greedy
        continuation of ``hypo`` also determines the continuation of
        every hypothesis along the greedy path, so all of them are added
        to the cache.
        """
        cached_cost = self.cache.get(hypo.trgt_sentence)
        if cached_cost is not None:
            return cached_cost
        words, scores = self._greedy_decode(hypo)
        # The greedy continuation of ``hypo.trgt_sentence + words[:i]`` is
        # ``words[i:]``, so its future cost is the negated sum of the
        # remaining scores. i == 0 caches ``hypo`` itself.
        for i in range(len(scores)):
            self.cache.add(hypo.trgt_sentence + words[:i], -sum(scores[i:]))
        return -sum(scores)

    def estimate_future_cost_without_cache(self, hypo):
        """Uncached variant of ``estimate_future_cost``. """
        _, scores = self._greedy_decode(hypo)
        return -sum(scores)

    def _greedy_decode(self, hypo):
        """Runs greedy decoding starting from the current predictor
        states, beginning by consuming the last word of
        ``hypo.trgt_sentence``, until EOS is produced.

        The predictor states are restored afterwards, even if decoding
        raises, so that the calling decoder is not left in a corrupted
        state.

        Returns:
            tuple. ``(words, scores)`` with the greedily chosen words and
            the posterior score of each chosen word.
        """
        old_states = self.decoder.get_predictor_states()
        # Decode on a copy so that the original states stay untouched.
        self.decoder.set_predictor_states(copy.deepcopy(old_states))
        words = []
        scores = []
        try:
            trgt_word = hypo.trgt_sentence[-1]
            while trgt_word != utils.EOS_ID:
                self.decoder.consume(trgt_word)
                posterior, _ = self.decoder.apply_predictors()
                trgt_word = utils.argmax(posterior)
                scores.append(posterior[trgt_word])
                words.append(trgt_word)
        finally:
            self.decoder.set_predictor_states(old_states)
        return words, scores
コード例 #2
0
ファイル: bow.py プロジェクト: ucam-smt/sgnmt
class BagOfWordsPredictor(Predictor):
    """This predictor is similar to the forced predictor, but it does
    not enforce the word order in the reference. Therefore, it assigns
    score 0.0 (probability 1) to all hypotheses which consist of the
    words in the reference in any order, and -inf to all other hypos.
    """

    def __init__(self,
                 trg_test_file,
                 accept_subsets=False,
                 accept_duplicates=False,
                 heuristic_scores_file="",
                 collect_stats_strategy='best',
                 heuristic_add_consumed = False,
                 heuristic_add_remaining = True,
                 diversity_heuristic_factor = -1.0,
                 equivalence_vocab=-1):
        """Creates a new bag-of-words predictor.

        Args:
            trg_test_file (string): Path to the plain text file with
                                    the target sentences. Must have the
                                    same number of lines as the number
                                    of source sentences to decode. The
                                    word order in the target sentences
                                    is not relevant for this predictor.
            accept_subsets (bool): If true, this predictor permits
                                   EOS even if the bag is not fully
                                   consumed yet
            accept_duplicates (bool): If true, counts are not updated
                                      when a word is consumed. This
                                      means that we allow a word in a
                                      bag to appear multiple times
            heuristic_scores_file (string): Path to the unigram scores
                                            which are used if this
                                            predictor estimates future
                                            costs
            collect_stats_strategy (string): best, full, or all. Defines
                                             how unigram estimates are
                                             collected for heuristic.
                                             Unknown values fall back to
                                             'best' (with an error log)
            heuristic_add_consumed (bool): Set to true to add the
                                           difference between actual
                                           partial score and unigram
                                           estimates of consumed words
                                           to the predictor heuristic
            heuristic_add_remaining (bool): Set to true to add the sum
                                            of unigram scores of words
                                            remaining in the bag to the
                                            predictor heuristic
            diversity_heuristic_factor (float): Factor for diversity
                                                heuristic which
                                                penalizes hypotheses
                                                with the same bag as
                                                full hypos
            equivalence_vocab (int): If positive, predictor states are
                                     considered equal if the
                                     remaining words within that vocab
                                     and OOVs regarding this vocab are
                                     the same. Only relevant when using
                                     hypothesis recombination
        """
        super(BagOfWordsPredictor, self).__init__()
        with open(trg_test_file) as f:
            self.lines = f.read().splitlines()
        if heuristic_scores_file:
            self.estimates = FileUnigramTable(heuristic_scores_file)
        elif collect_stats_strategy == 'best':
            self.estimates = BestStatsUnigramTable()
        elif collect_stats_strategy == 'full':
            self.estimates = FullStatsUnigramTable()
        elif collect_stats_strategy == 'all':
            self.estimates = AllStatsUnigramTable()
        else:
            # Without a fallback ``self.estimates`` would stay unset and
            # ``initialize_heuristic``/``notify`` would fail later with an
            # AttributeError.
            logging.error("Unknown statistics collection strategy '%s'. "
                          "Falling back to 'best'.", collect_stats_strategy)
            self.estimates = BestStatsUnigramTable()
        self.accept_subsets = accept_subsets
        self.accept_duplicates = accept_duplicates
        self.heuristic_add_consumed = heuristic_add_consumed
        self.heuristic_add_remaining = heuristic_add_remaining
        self.equivalence_vocab = equivalence_vocab
        if accept_duplicates and not accept_subsets:
            logging.error("You enabled bow_accept_duplicates but not bow_"
                          "accept_subsets. Therefore, the bow predictor will "
                          "never accept end-of-sentence and could cause "
                          "an infinite loop in the search strategy.")
        self.diversity_heuristic_factor = diversity_heuristic_factor
        self.diverse_heuristic = (diversity_heuristic_factor > 0.0)

    def get_unk_probability(self, posterior):
        """Returns negative infinity unconditionally: Words which are
        not in the target sentence have assigned probability 0 by
        this predictor.
        """
        return NEG_INF

    def predict_next(self):
        """If the bag is empty, the only allowed symbol is EOS.
        Otherwise, return the list of keys in the bag (plus EOS if
        ``accept_subsets`` is enabled), each with score 0.0.
        """
        if not self.bag:  # Empty bag: only EOS is allowed
            return {utils.EOS_ID: 0.0}
        ret = dict.fromkeys(self.bag, 0.0)
        if self.accept_subsets:
            ret[utils.EOS_ID] = 0.0
        return ret

    def initialize(self, src_sentence):
        """Creates a new bag (word id -> count) for the current target
        sentence.

        Args:
            src_sentence (list):  Not used
        """
        self.best_hypo_score = NEG_INF
        self.bag = {}
        for w in self.lines[self.current_sen_id].strip().split():
            int_w = int(w)
            self.bag[int_w] = self.bag.get(int_w, 0) + 1
        self.full_bag = dict(self.bag)

    def consume(self, word):
        """Updates the bag by deleting the consumed word. EOS empties
        the bag. If ``accept_duplicates`` is enabled, counts are not
        updated.

        Args:
            word (int): Next word to consume
        """
        if word == utils.EOS_ID:
            self.bag = {}
            return
        if word not in self.bag:
            logging.warning("Consuming word which is not in bag-of-words!")
            return
        if self.accept_duplicates:
            # Counts are not updated: the word remains available and may
            # appear again. Hence the bag never empties (see the warning
            # in ``__init__``).
            return
        cnt = self.bag.pop(word)
        if cnt > 1:
            self.bag[word] = cnt - 1

    def get_state(self):
        """State of this predictor is the current bag """
        return self.bag

    def set_state(self, state):
        """State of this predictor is the current bag """
        self.bag = state

    def initialize_heuristic(self, src_sentence):
        """Calls ``reset`` of the used unigram table with estimates
        ``self.estimates`` to clear all statistics from the previous
        sentence

        Args:
            src_sentence (list): Not used
        """
        self.estimates.reset()
        if self.diverse_heuristic:
            self.explored_bags = SimpleTrie()

    def notify(self, message, message_type = MESSAGE_TYPE_DEFAULT):
        """This gets called if this predictor observes the decoder. It
        updates unigram heuristic estimates via passing through this
        message to the unigram table ``self.estimates``.
        """
        self.estimates.notify(message, message_type)
        if self.diverse_heuristic and message_type == MESSAGE_TYPE_FULL_HYPO:
            self._update_explored_bags(message)

    def _update_explored_bags(self, hypo):
        """This is called if diversity heuristic is enabled. It counts,
        in ``self.explored_bags``, the bags (sorted word lists) of all
        proper prefixes of the full hypothesis ``hypo``.
        """
        sen = hypo.trgt_sentence
        for prefix_len in range(len(sen)):
            key = sorted(sen[:prefix_len])
            cnt = self.explored_bags.get(key) or 0.0
            self.explored_bags.add(key, cnt + 1.0)

    def estimate_future_cost(self, hypo):
        """The bow predictor comes with its own heuristic function. We
        use the sum of scores of the remaining words as future cost
        estimator.
        """
        acc = 0.0
        if self.heuristic_add_remaining:
            remaining = dict(self.full_bag)
            remaining[utils.EOS_ID] = 1
            for w in hypo.trgt_sentence:
                # ``get`` avoids a KeyError for words outside the reference
                remaining[w] = remaining.get(w, 0) - 1
            # Only words still left in the bag contribute. Words consumed
            # more often than in the reference (count <= 0) must not lower
            # the estimate, and skipping zero counts avoids 0 * inf = nan.
            acc -= sum([cnt * self.estimates.estimate(w)
                        for w, cnt in remaining.items() if cnt > 0])
        if self.diverse_heuristic:
            key = sorted(hypo.trgt_sentence)
            cnt = self.explored_bags.get(key)
            if cnt:
                acc += cnt * self.diversity_heuristic_factor
        if self.heuristic_add_consumed:
            acc -= hypo.score - sum([self.estimates.estimate(w, -1000.0)
                                     for w in hypo.trgt_sentence])
        return acc

    def _get_unk_bag(self, org_bag):
        """Maps all words >= ``equivalence_vocab`` in ``org_bag`` to UNK
        and merges their counts. Returns ``org_bag`` unchanged if
        ``equivalence_vocab`` is not positive.
        """
        if self.equivalence_vocab <= 0:
            return org_bag
        unk_bag = {}
        for word, cnt in org_bag.items():
            idx = word if word < self.equivalence_vocab else utils.UNK_ID
            unk_bag[idx] = unk_bag.get(idx, 0) + cnt
        return unk_bag

    def is_equal(self, state1, state2):
        """Returns true if the bag is the same """
        return self._get_unk_bag(state1) == self._get_unk_bag(state2)
コード例 #3
0
class BlocksNMTPredictor(Predictor):
    """This is the neural machine translation predictor. The predicted
    posteriors are equal to the distribution generated by the decoder
    network in NMT. This predictor heavily relies on the NMT example in
    blocks. Note that this predictor cannot be used in combination with
    a target side sparse feature map. See
    ``BlocksUnboundedNMTPredictor`` for that case.
    """

    def __init__(self, nmt_model_path, gnmt_beta, enable_cache, config):
        """Creates a new NMT predictor.

        Args:
            nmt_model_path (string):  Path to the NMT model file (.npz)
            gnmt_beta (float): If greater than 0.0, add a Google NMT
                               style coverage penalization term (Wu et
                               al., 2016) to the predictive scores
            enable_cache (bool):  The NMT predictor usually has a very
                                  limited vocabulary size, and a large
                                  number of UNKs in hypotheses. This
                                  enables reusing already computed
                                  predictor states for hypotheses which
                                  differ only by NMT OOV words.
            config (dict): NMT configuration

        Note:
            A target sparse feature map (``config['trg_sparse_feat_map']``)
            is not supported. ``set_up_predictor`` only logs a fatal-level
            message in that case and continues; no exception is raised.
        """
        super(BlocksNMTPredictor, self).__init__()
        self.gnmt_beta = gnmt_beta
        self.add_gnmt_coverage_term = gnmt_beta > 0.0
        self.config = copy.deepcopy(config)
        self.enable_cache = enable_cache
        self.set_up_predictor(nmt_model_path)
        self.src_eos = self.src_sparse_feat_map.word2dense(utils.EOS_ID)

    def set_up_predictor(self, nmt_model_path):
        """Initializes the predictor with the given NMT model. Code
        following ``blocks.machine_translation.main``.
        """
        self.src_vocab_size = self.config['src_vocab_size']
        self.trgt_vocab_size = self.config['trg_vocab_size']
        self.nmt_model = NMTModel(self.config)
        self.nmt_model.set_up()
        loader = LoadNMTUtils(nmt_model_path,
                              self.config['saveto'],
                              self.nmt_model.search_model)
        loader.load_weights()

        self.best_models = []
        self.val_bleu_curve = []
        self.src_sparse_feat_map = self.config['src_sparse_feat_map'] \
                if self.config['src_sparse_feat_map'] else FlatSparseFeatMap()
        if self.config['trg_sparse_feat_map']:
            logging.fatal("Cannot use bounded vocabulary predictor with "
                          "a target sparse feature map. Ignoring...")
        self.search_algorithm = MyopticSearch(samples=self.nmt_model.samples)
        self.search_algorithm.compile()

    def initialize(self, src_sentence):
        """Runs the encoder network to create the source annotations
        for the source sentence. If the cache is enabled, empty the
        cache.

        Args:
            src_sentence (list): List of word ids without <S> and </S>
                                 which represent the source sentence.
        """
        self.contexts = None
        self.states = None
        self.posterior_cache = SimpleTrie()
        self.states_cache = SimpleTrie()
        self.consumed = []
        seq = self.src_sparse_feat_map.words2dense(
                    utils.oov_to_unk(src_sentence,
                                     self.src_vocab_size)) + [self.src_eos]
        if self.src_sparse_feat_map.dim > 1:  # sparse src feats
            input_ = np.transpose(np.tile(seq, (1, 1, 1)), (2, 0, 1))
        else:  # word ids on the source side
            input_ = np.tile(seq, (1, 1))

        input_values = {self.nmt_model.sampling_input: input_}
        self.contexts, self.states, _ = \
            self.search_algorithm.compute_initial_states_and_contexts(
                input_values)
        # One accumulated attention weight per source position (incl. EOS)
        self.attention_records = (1 + len(src_sentence)) * [0.0]

    def is_history_cachable(self):
        """Returns true if cache is enabled and history contains UNK """
        if not self.enable_cache:
            return False
        return utils.UNK_ID in self.consumed

    def predict_next(self):
        """Uses cache or runs the decoder network to get the
        distribution over the next target words.

        Returns:
            np array. Full distribution over the entire NMT vocabulary
            for the next target token.
        """
        use_cache = self.is_history_cachable()
        if use_cache:
            posterior = self.posterior_cache.get(self.consumed)
            if posterior is not None:
                logging.debug("Loaded NMT posterior from cache for %s",
                              self.consumed)
                return self._add_gnmt_beta(posterior)
        # logprobs are negative log probs, i.e. greater than 0
        logprobs = self.search_algorithm.compute_logprobs(self.contexts,
                                                          self.states)
        posterior = np.multiply(logprobs[0], -1.0)
        if use_cache:
            self.posterior_cache.add(self.consumed, posterior)
        return self._add_gnmt_beta(posterior)

    def _add_gnmt_beta(self, posterior):
        """Adds the GNMT coverage penalization term to EOS.

        Returns a modified copy rather than changing ``posterior`` in
        place: ``posterior`` may be the array stored in
        ``self.posterior_cache``. In-place updates would accumulate the
        penalty on every cache hit.
        """
        if not self.add_gnmt_coverage_term:
            return posterior
        penalty = self.gnmt_beta * sum([np.log(max(0.0001, p))
                                        for p in self.attention_records
                                        if p < 1.0])
        posterior = np.copy(posterior)
        posterior[utils.EOS_ID] += penalty
        return posterior

    def get_unk_probability(self, posterior):
        """Returns the UNK probability defined by NMT. """
        return posterior[utils.UNK_ID] if len(posterior) > utils.UNK_ID else NEG_INF

    def consume(self, word):
        """Feeds back ``word`` to the decoder network. This includes
        embedding of ``word``, running the attention network and update
        the recurrent decoder layer. Words outside the NMT vocabulary are
        mapped to UNK.
        """
        if word >= self.trgt_vocab_size:
            word = utils.UNK_ID
        self.consumed.append(word)
        use_cache = self.is_history_cachable()
        cached_states = None
        if use_cache:
            cached_states = self.states_cache.get(self.consumed)
        if cached_states is not None:
            logging.debug("Loaded NMT decoder states from cache for %s",
                          self.consumed)
            self.states = copy.deepcopy(cached_states)
        else:
            self.states.update(self.search_algorithm.compute_next_states(
                    self.contexts, self.states, [word]))
            if use_cache:
                self.states_cache.add(self.consumed,
                                      copy.deepcopy(self.states))
        # Keep track of attentions. This must also happen on a cache hit:
        # the cached states carry the attention 'weights' of this step.
        if self.add_gnmt_coverage_term:
            for pos, att in enumerate(self.states['weights'][0]):
                self.attention_records[pos] += att

    def get_state(self):
        """The NMT predictor state consists of the decoder network
        state, the current history of consumed words (for caching), and
        the accumulated attention records (for the coverage term).
        """
        return self.states, self.consumed, self.attention_records

    def set_state(self, state):
        """Set the NMT predictor state. """
        self.states, self.consumed, self.attention_records = state

    def is_equal(self, state1, state2):
        """Returns true if the history is the same """
        _, consumed1, _ = state1
        _, consumed2, _ = state2
        return consumed1 == consumed2
コード例 #4
0
ファイル: length.py プロジェクト: Jack44Wang/sgnmt
class NgramCountPredictor(Predictor):
    """This predictor counts the number of n-grams in hypotheses. n-gram
    posteriors are loaded from a file. The predictor score is the sum of
    all n-gram posteriors in a hypothesis. """
    def __init__(self, path, order=0, discount_factor=-1.0):
        """Creates a new ngram count predictor instance.

        Args:
            path (string): Path to the n-gram posteriors. File format:
                           <ngram> : <score> (one ngram per line). Use
                           placeholder %d for sentence id.
            order (int): If positive, count n-grams of the specified
                         order. Otherwise, count all n-grams
            discount_factor (float): If non-negative, discount n-gram
                                     posteriors by this factor each time
                                     they are consumed
        """
        super(NgramCountPredictor, self).__init__()
        self.path = path
        self.order = order
        self.discount_factor = discount_factor

    def get_unk_probability(self, posterior):
        """Always return 0.0 """
        return 0.0

    def predict_next(self):
        """Composes the posterior vector by collecting all ngrams which
        are consistent with the current history. Each suffix of the
        history (including the empty one) contributes the scores of the
        n-grams with that suffix as context, discounted if enabled.
        """
        posterior = {}
        for i in reversed(range(len(self.cur_history) + 1)):
            context = self.cur_history[i:]
            scores = self.ngrams.get(context)
            if not scores:
                continue
            factors = None
            if self.discount_factor >= 0.0:
                factors = self.discounts.get(context)
            for w, score in scores.items():
                if factors:
                    score = factors.get(w, 1.0) * score
                posterior[w] = posterior.get(w, 0.0) + score
        return posterior

    def _load_posteriors(self, path):
        """Sets up self.max_history_len and self.ngrams. Blank lines are
        skipped. The score is taken after the last ':' on each line.
        """
        self.max_history_len = 0
        self.ngrams = SimpleTrie()
        with open(path) as f:
            for line in f:
                if not line.strip():
                    continue
                ngram, score = line.rsplit(':', 1)
                score = float(score.strip())
                words = [int(w) for w in ngram.strip().split()]
                if self.order > 0 and len(words) != self.order:
                    continue
                hist = words[:-1]
                last_word = words[-1]
                if last_word == utils.GO_ID:
                    continue
                self.max_history_len = max(self.max_history_len, len(hist))
                p = self.ngrams.get(hist)
                if p:
                    p[last_word] = score
                else:
                    self.ngrams.add(hist, {last_word: score})

    def initialize(self, src_sentence):
        """Loads n-gram posteriors and resets history.

        Args:
            src_sentence (list): not used
        """
        self._load_posteriors(
            utils.get_path(self.path, self.current_sen_id + 1))
        self.cur_history = [utils.GO_ID]
        self.discounts = SimpleTrie()

    def consume(self, word):
        """Adds ``word`` to the current history and, if discounting is
        enabled, discounts ``word`` under every context it was consumed
        in. Shorten the history if it exceeds ``max_history_len``.

        Args:
            word (int): Word to add to the history.
        """
        self.cur_history.append(word)
        if self.discount_factor >= 0.0:
            # Discount before truncating: otherwise the longest context
            # (the one that just fell out of the window) would never be
            # discounted.
            for i in range(len(self.cur_history)):
                key = self.cur_history[i:-1]
                factors = self.discounts.get(key)
                if not factors:
                    factors = {word: self.discount_factor}
                else:
                    factors[word] = factors.get(word,
                                                1.0) * self.discount_factor
                self.discounts.add(key, factors)
        if len(self.cur_history) > self.max_history_len:
            # Use an explicit start index: ``[-0:]`` would keep the whole
            # list when max_history_len is 0.
            start = len(self.cur_history) - self.max_history_len
            self.cur_history = self.cur_history[start:]

    def get_state(self):
        """Current history and discount factors are the predictor state """
        return self.cur_history, self.discounts

    def set_state(self, state):
        """Current history and discount factors are the predictor state """
        self.cur_history, self.discounts = state

    def reset(self):
        """Empty method. """
        pass

    def is_equal(self, state1, state2):
        """Returns true if both histories lead to the same n-gram
        lookups, i.e. all non-matching history suffixes are unknown to
        ``self.ngrams``. Hypothesis recombination is not supported if
        discounting is enabled.
        """
        if self.discount_factor >= 0.0:
            return False
        hist1 = state1[0]
        hist2 = state2[0]
        if hist1 == hist2:  # Return true if histories match
            return True
        if len(hist1) > len(hist2):
            hist_long, hist_short = hist1, hist2
        else:
            hist_long, hist_short = hist2, hist1
        min_len = len(hist_short)
        for n in range(1, min_len + 1):  # Look up non matching in self.ngrams
            key1 = hist1[-n:]
            key2 = hist2[-n:]
            if key1 != key2:
                if self.ngrams.get(key1) or self.ngrams.get(key2):
                    return False
        for n in range(min_len + 1, len(hist_long) + 1):
            if self.ngrams.get(hist_long[-n:]):
                return False
        return True
コード例 #5
0
ファイル: tf_t2t.py プロジェクト: strategist922/sgnmt
class EditT2TPredictor(_BaseTensor2TensorPredictor):
    """This predictor can be used for T2T models conditioning on the
    full target sentence. The predictor state is a full target sentence.
    The state can be changed by insertions, substitutions, and deletions
    of single tokens, where each operation is encoded as an SGNMT token
    in the following way:

      1xxxyyyyy: Insert the token yyyyy at position xxx.
      2xxxyyyyy: Replace the xxx-th word with the token yyyyy.
      3xxx00000: Delete the xxx-th token.

    The trailing EOS token of the target sentence is never edited.
    """

    INS_OFFSET = 100000000
    SUB_OFFSET = 200000000
    DEL_OFFSET = 300000000

    POS_FACTOR = 100000
    MAX_SEQ_LEN = 999

    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 model_name,
                 problem_name,
                 hparams_set_name,
                 trg_test_file,
                 beam_size,
                 t2t_usr_dir,
                 checkpoint_dir,
                 t2t_unk_id=None,
                 n_cpu_threads=-1,
                 max_terminal_id=-1,
                 pop_id=-1):
        """Creates a new edit T2T predictor. This constructor is
        similar to the constructor of T2TPredictor but creates a
        different computation graph which retrieves scores at each
        target position, not only the last one.
        
        Args:
            src_vocab_size (int): Source vocabulary size.
            trg_vocab_size (int): Target vocabulary size. Must be less
                than ``POS_FACTOR`` so that tokens fit in the edit
                encoding.
            model_name (string): T2T model name.
            problem_name (string): T2T problem name.
            hparams_set_name (string): T2T hparams set name.
            trg_test_file (string): Path to a plain text file with
                initial target sentences. Can be empty.
            beam_size (int): Determines how many substitutions and
                insertions are considered at each position
                (``max(1, beam_size // 10) + 1`` candidates).
            t2t_usr_dir (string): See --t2t_usr_dir in tensor2tensor.
            checkpoint_dir (string): Path to the T2T checkpoint 
                                     directory. The predictor will load
                                     the top most checkpoint in the 
                                     `checkpoints` file.
            t2t_unk_id (int): If set, use this ID to get UNK scores. If
                              None, UNK is always scored with -inf.
            n_cpu_threads (int): Number of TensorFlow CPU threads.
            max_terminal_id (int): If positive, maximum terminal ID. Needs to
                be set for syntax-based T2T models.
            pop_id (int): If positive, ID of the POP or closing bracket symbol.
                Needs to be set for syntax-based T2T models.

        Raises:
            AttributeError: If model/problem/hparams names are missing or
                the target vocabulary is too large.
        """
        super(EditT2TPredictor, self).__init__(t2t_usr_dir, 
                                               checkpoint_dir, 
                                               src_vocab_size,
                                               trg_vocab_size,
                                               t2t_unk_id, 
                                               n_cpu_threads,
                                               max_terminal_id,
                                               pop_id)
        if not model_name or not problem_name or not hparams_set_name:
            logging.fatal(
                "Please specify t2t_model, t2t_problem, and t2t_hparams_set!")
            raise AttributeError
        if trg_vocab_size >= EditT2TPredictor.POS_FACTOR:
            logging.fatal("Target vocabulary size (%d) must be less than %d!"
                          % (trg_vocab_size, EditT2TPredictor.POS_FACTOR))
            raise AttributeError
        # Number of candidate tokens per position for subs/insertions.
        self.beam_size = max(1, beam_size // 10) + 1
        self.batch_size = 2048 # TODO(fstahlberg): Move to config
        self.initial_trg_sentences = None
        if trg_test_file: 
            self.initial_trg_sentences = []
            with open(trg_test_file) as f:
                for line in f:
                    self.initial_trg_sentences.append(utils.oov_to_unk(
                       [int(w) for w in line.strip().split()] + [utils.EOS_ID],
                       self.trg_vocab_size, self._t2t_unk_id))
        predictor_graph = tf.Graph()
        with predictor_graph.as_default():
            hparams = trainer_lib.create_hparams(hparams_set_name)
            self._add_problem_hparams(hparams, problem_name)
            translate_model = registry.model(model_name)(
                hparams, tf.estimator.ModeKeys.EVAL)
            self._inputs_var = tf.placeholder(dtype=tf.int32, shape=[None],
                                              name="sgnmt_inputs")
            self._targets_var = tf.placeholder(dtype=tf.int32, shape=[None, None], 
                                               name="sgnmt_targets")
            shp = tf.shape(self._targets_var)
            bsz = shp[0]
            # The single source sentence is shared by all batch entries.
            inputs = tf.tile(tf.expand_dims(self._inputs_var, 0), [bsz, 1])
            features = {"inputs": expand_input_dims_for_t2t(inputs,
                                                            batched=True), 
                        "targets": expand_input_dims_for_t2t(self._targets_var,
                                                             batched=True)}
            translate_model.prepare_features_for_infer(features)
            translate_model._fill_problem_hparams_features(features)
            logits, _ = translate_model(features)
            logits = tf.squeeze(logits, [2, 3])
            # Log-probabilities at every target position.
            self._log_probs = log_prob_from_logits(logits)
            # For batch entry b, the distribution at target position b.
            # Used for insertions, where row b has the inserted token at b.
            diag_logits = gather_2d(logits, tf.expand_dims(tf.range(bsz), 1))
            self._diag_log_probs = log_prob_from_logits(diag_logits)
            no_pad = tf.cast(tf.not_equal(
                self._targets_var, text_encoder.PAD_ID), tf.float32)
            flat_bsz = shp[0] * shp[1]
            word_scores = gather_2d(
                tf.reshape(self._log_probs, [flat_bsz, -1]),
                tf.reshape(self._targets_var, [flat_bsz, 1]))
            # Padding positions do not contribute to the sentence score.
            word_scores = tf.reshape(word_scores, (shp[0], shp[1])) * no_pad
            self._sentence_scores = tf.reduce_sum(word_scores, -1)
            self.mon_sess = self.create_session()

    def _ins_op(self, pos, token):
        """Returns a copy of trg sentence with ``token`` inserted before
        index ``pos``."""
        return self.trg_sentence[:pos] + [token] + self.trg_sentence[pos:]

    def _sub_op(self, pos, token):
        """Returns a copy of trg sentence with index ``pos`` replaced by
        ``token``."""
        ret = list(self.trg_sentence)
        ret[pos] = token
        return ret

    def _del_op(self, pos):
        """Returns a copy of trg sentence with index ``pos`` removed."""
        return self.trg_sentence[:pos] + self.trg_sentence[pos+1:]

    def _top_n(self, scores, sort=False):
        """Indices of the ``self.beam_size`` best entries in each row.

        Args:
            scores (np.ndarray): 2D array of scores (higher is better).
            sort (bool): If true, indices are ordered best first;
                otherwise their order within a row is arbitrary.

        Returns:
            np.ndarray. Column indices, one row per input row.
            ``utils.EOS_ID`` is never selected.
        """
        costs = -scores
        costs[:, utils.EOS_ID] = utils.INF
        top_n_indices = np.argpartition(
            costs,
            self.beam_size, 
            axis=1)[:, :self.beam_size]
        if not sort:
            return top_n_indices
        b_indices = np.expand_dims(np.arange(top_n_indices.shape[0]), axis=1)
        sorted_indices = np.argsort(costs[b_indices, top_n_indices], axis=1)
        return top_n_indices[b_indices, sorted_indices]

    def predict_next(self):
        """Score all single-token edits of the current target sentence.

        Candidates are substitutions of each non-EOS token by its top
        ``beam_size`` tokens, insertions of the top ``beam_size`` tokens
        before each position (if the sentence is short enough), and
        deletions of each non-EOS token.

        Returns:
            dict. Maps edit descriptors (see class docstring) to the score
            of the edited sentence minus ``self.cur_score``.
            ``utils.EOS_ID`` (stop editing) maps to 0.0.
        """
        next_sentences = {}
        logging.debug("EditT2T: Exploring score=%f sentence=%s",
                      self.cur_score, " ".join(map(str, self.trg_sentence)))
        n_trg_words = len(self.trg_sentence)
        if n_trg_words > EditT2TPredictor.MAX_SEQ_LEN:
            # Was EDITT2TPredictor (undefined name -> NameError).
            logging.warning(
                "EditT2T: Target sentence exceeds maximum length (%d)",
                EditT2TPredictor.MAX_SEQ_LEN)
            return {utils.EOS_ID: 0.0}
        # Substitutions
        log_probs = self.mon_sess.run(self._log_probs,
            {self._inputs_var: self.src_sentence,
             self._targets_var: [self.trg_sentence]})
        top_n = self._top_n(np.squeeze(log_probs, axis=0))
        for pos, cur_token in enumerate(self.trg_sentence[:-1]):
            offset = EditT2TPredictor.SUB_OFFSET 
            offset += EditT2TPredictor.POS_FACTOR * pos
            for token in top_n[pos]:
                if token != cur_token:
                    next_sentences[offset + token] = self._sub_op(pos, token)

        # Insertions
        if n_trg_words < EditT2TPredictor.MAX_SEQ_LEN - 1:
            # Row ``pos`` is the sentence with a placeholder (999) at
            # index ``pos``; the diagonal gather reads the distribution
            # at that index.
            ins_trg_sentences = np.full((n_trg_words, n_trg_words+1), 999)
            for pos in range(n_trg_words):
                ins_trg_sentences[pos, :pos] = self.trg_sentence[:pos]
                ins_trg_sentences[pos, pos+1:] = self.trg_sentence[pos:]
            diag_log_probs = self.mon_sess.run(self._diag_log_probs,
                {self._inputs_var: self.src_sentence,
                 self._targets_var: ins_trg_sentences})
            top_n = self._top_n(np.squeeze(diag_log_probs, axis=1))
            for pos in range(n_trg_words):
                offset = EditT2TPredictor.INS_OFFSET 
                offset += EditT2TPredictor.POS_FACTOR * pos
                for token in top_n[pos]:
                    next_sentences[offset + token] = self._ins_op(pos, token)
        # Deletions
        idx = EditT2TPredictor.DEL_OFFSET
        for pos in range(n_trg_words - 1): # -1: Do not delete EOS
            next_sentences[idx] = self._del_op(pos)
            idx += EditT2TPredictor.POS_FACTOR
        abs_scores = self._score(next_sentences, n_trg_words + 1)
        rel_scores = {i: s - self.cur_score 
                      for i, s in abs_scores.items()}
        rel_scores[utils.EOS_ID] = 0.0
        return rel_scores

    def _score(self, sentences, n_trg_words=1):
        """Score full target sentences, using the cache where possible.

        Uncached sentences are zero-padded to ``n_trg_words`` and run
        through the model in batches of at most
        ``max(1, batch_size // n_trg_words)`` sentences.

        Args:
            sentences (dict): Maps arbitrary IDs to token lists.
            n_trg_words (int): Padded length; must be at least the length
                of the longest sentence.

        Returns:
            dict. Maps the IDs in ``sentences`` to sentence scores.
        """
        max_n_sens = max(1, self.batch_size // n_trg_words)
        scores = {}
        batch_ids = []
        batch_sens = []
        batch_keys = []
        for idx, trg_sentence in sentences.items():
            score = self.cache.get(trg_sentence)
            if score is not None:
                scores[idx] = score
                continue
            # ``int`` is what the removed ``np.int`` alias referred to.
            np_sen = np.zeros(n_trg_words, dtype=int)
            np_sen[:len(trg_sentence)] = trg_sentence
            batch_ids.append(idx)
            batch_sens.append(np_sen)
            batch_keys.append(trg_sentence)
            if len(batch_ids) >= max_n_sens:
                self._score_single_batch(scores, batch_ids, batch_sens,
                                         batch_keys)
                batch_ids = []
                batch_sens = []
                batch_keys = []
        self._score_single_batch(scores, batch_ids, batch_sens, batch_keys)
        return scores 

    def _score_single_batch(self, scores, ids, trg_sentences,
                            cache_keys=None):
        """Score sentences and add them to ``scores`` and the cache.

        Args:
            scores (dict): Updated in place with ``ids[i] -> score``.
            ids (list): IDs of the sentences in this batch.
            trg_sentences (list): Padded sentences (equal-length arrays).
            cache_keys (list): Unpadded sentences used as cache keys. They
                must match the keys ``_score`` looks up, otherwise cached
                scores are never found. Defaults to ``trg_sentences``.
        """
        if not ids:
            return
        if cache_keys is None:
            cache_keys = trg_sentences
        batch_scores = self.mon_sess.run(self._sentence_scores,
            {self._inputs_var: self.src_sentence,
             self._targets_var: np.stack(trg_sentences)})
        for idx, key, score in zip(ids, cache_keys, batch_scores):
            self.cache.add(key, score)
            scores[idx] = score

    def _update_cur_score(self):
        """Set ``self.cur_score`` for the current target sentence.

        The cached value is used if present. This can be
        ``utils.NEG_INF`` for sentences already reached via ``consume``.
        """
        self.cur_score = self.cache.get(self.trg_sentence)
        if self.cur_score is None:
            scores = self._score({1: self.trg_sentence}, len(self.trg_sentence))
            self.cur_score = scores[1]
            self.cache.add(self.trg_sentence, self.cur_score)
    
    def initialize(self, src_sentence):
        """Set the source sentence and the initial target sentence.

        The initial target is ``[EOS]``, or the line of ``trg_test_file``
        for the current sentence ID. Resets the score cache and scores the
        initial target sentence.
        """
        if self.initial_trg_sentences is None:
            self.trg_sentence = [text_encoder.EOS_ID]
        else:
            self.trg_sentence = self.initial_trg_sentences[self.current_sen_id]
        self.src_sentence = utils.oov_to_unk(
            src_sentence + [text_encoder.EOS_ID], 
            self.src_vocab_size, self._t2t_unk_id)
        self.cache = SimpleTrie()
        self._update_cur_score()
        logging.debug("Initial score: %f", self.cur_score)
   
    def consume(self, word):
        """Apply the edit operation encoded in ``word``.

        ``word`` is decoded as described in the class docstring. EOS is
        ignored; unknown operation codes are logged and ignored. The
        resulting sentence is scored, then marked with ``utils.NEG_INF``
        in the cache, so later ``predict_next`` calls give edits leading
        back to it a score of -inf.
        """
        if word == utils.EOS_ID:
            return
        pos = (word // EditT2TPredictor.POS_FACTOR) \
              % (EditT2TPredictor.MAX_SEQ_LEN + 1)
        token = word % EditT2TPredictor.POS_FACTOR
        # INS/SUB/DEL offsets are 1x, 2x, 3x INS_OFFSET.
        op = word // EditT2TPredictor.INS_OFFSET
        if op == 1:  # Insertion
            self.trg_sentence = self._ins_op(pos, token)
        elif op == 2:  # Substitution
            self.trg_sentence = self._sub_op(pos, token)
        elif op == 3:  # Deletion
            self.trg_sentence = self._del_op(pos)
        else:
            logging.warning("Invalid edit descriptor %d. Ignoring...", word)
        self._update_cur_score()
        self.cache.add(self.trg_sentence, utils.NEG_INF)
    
    def get_state(self):
        """The predictor state is the complete target sentence and its
        score."""
        return self.trg_sentence, self.cur_score
    
    def set_state(self, state):
        """Restore a ``(trg_sentence, cur_score)`` pair from
        ``get_state``."""
        self.trg_sentence, self.cur_score = state

    def is_equal(self, state1, state2):
        """Returns true if the target sentence is the same """
        return state1[0] == state2[0]
コード例 #6
0
ファイル: tf_nmt.py プロジェクト: ucam-smt/sgnmt
class TensorFlowNMTPredictor(Predictor):
  """Neural MT predictor backed by a TensorFlow model.

  The model is loaded with ``tf_model_utils.load_model`` and decodes one
  sentence at a time, one target token per ``predict_next`` call. If
  caching is enabled, posteriors and decoder states are cached for
  histories that contain UNK (see ``is_history_cachable``).
  """

  def __init__(self, enable_cache, config, session):
    """Creates a new TensorFlow NMT predictor.

    Args:
      enable_cache (bool): Enable posterior/decoder state caching for
        histories containing UNK.
      config (dict): Model configuration. Entries not used by the chosen
        encoder are filled in here (``config`` is mutated).
      session: TensorFlow session used to load and run the model.
    """
    super(TensorFlowNMTPredictor, self).__init__()
    self.config = config
    self.session = session

    # Add missing entries in config
    if self.config['encoder'] == "bow":
      self.config['init_backward'] = False
      self.config['use_seqlen'] = False
    else:
      self.config['bow_init_const'] = False
      self.config['use_bow_mask'] = False

    # Load tensorflow model
    (self.model, self.training_graph, self.encoding_graph,
     self.single_step_decoding_graph,
     self.buckets) = tf_model_utils.load_model(self.session, config)
    self.model.batch_size = 1  # We decode one sentence at a time.

    self.enc_out = {}
    self.decoder_input = [tf_data_utils.GO_ID]
    self.dec_state = {}
    self.bucket_id = -1
    self.num_heads = 1
    self.word_count = 0
    # Reset again in initialize(); defined here so that get_state() and
    # the caches do not fail with AttributeError before initialize().
    self.consumed = []
    self.posterior_cache = SimpleTrie()
    self.states_cache = SimpleTrie()

    if config['no_pad_symbol']:
      # This needs to be set in tensorflow data_utils for correct source masks
      tf_data_utils.no_pad_symbol()
      logging.info("UNK_ID=%d", tf_data_utils.UNK_ID)
      logging.info("PAD_ID=%d", tf_data_utils.PAD_ID)

    self.enable_cache = enable_cache
    if self.enable_cache:
      logging.info("Cache enabled..")

  def initialize(self, src_sentence):
    """Encodes the source sentence and resets the decoder state.

    Args:
      src_sentence (list): Source word IDs without <s> and </s>. IDs
        >= config['src_vocab_size'] are mapped to UNK, and EOS is
        appended if config['add_src_eos'] is set.
    """
    self.enc_out = {}
    self.decoder_input = [tf_data_utils.GO_ID]
    self.dec_state = {}
    self.word_count = 0
    self.consumed = []
    self.posterior_cache = SimpleTrie()
    self.states_cache = SimpleTrie()

    src_sentence = [w if w < self.config['src_vocab_size'] else tf_data_utils.UNK_ID
                    for w in src_sentence]
    if self.config['add_src_eos']:
      src_sentence.append(tf_data_utils.EOS_ID)

    # Pick the lowest-indexed bucket that fits the source length, or
    # create a new bucket if none does.
    feasible_buckets = [b for b in range(len(self.buckets))
                        if self.buckets[b][0] >= len(src_sentence)]
    if not feasible_buckets:
      # Get a new bucket
      bucket = tf_model_utils.make_bucket(len(src_sentence))
      logging.info("Add new bucket={} and update model".format(bucket))
      self.buckets.append(bucket)
      self.model.update_buckets(self.buckets)
      self.bucket_id = len(self.buckets) - 1
    else:
      self.bucket_id = min(feasible_buckets)

    encoder_inputs, _, _, sequence_length, src_mask, bow_mask = self.training_graph.get_batch(
            {self.bucket_id: [(src_sentence, [])]}, self.bucket_id, self.config['encoder'])
    logging.info("bucket={}".format(self.buckets[self.bucket_id]))

    last_enc_state, self.enc_out = self.encoding_graph.encode(
            self.session, encoder_inputs, self.bucket_id, sequence_length)

    # Initialize decoder state with last encoder state
    self.dec_state["dec_state"] = last_enc_state
    for a in range(self.num_heads):
      self.dec_state["dec_attns_%d" % a] = np.zeros((1, self.enc_out['enc_v_0'].size), dtype=np.float32)

    if self.config['use_src_mask']:
      self.dec_state["src_mask"] = src_mask
      self.src_mask_orig = src_mask.copy()

    if self.config['use_bow_mask']:
      self.dec_state["bow_mask"] = bow_mask
      self.bow_mask_orig = bow_mask.copy()

  def is_history_cachable(self):
    """Returns true if cache is enabled and history contains UNK """
    return self.enable_cache and tf_data_utils.UNK_ID in self.consumed

  def predict_next(self):
    """Returns the NMT posterior for the next target word.

    After EOS has been consumed, only EOS (score 0) is predicted. On a
    cache hit, the cached posterior is returned and the decoder state
    that the original ``decode`` call produced is restored.

    Returns:
      dict or array: ``{EOS_ID: 0}`` after EOS, otherwise the model's
      output for the single batch entry.
    """
    if self.decoder_input[0] == tf_data_utils.EOS_ID: # Predict EOS
      return {tf_data_utils.EOS_ID: 0}

    use_cache = self.is_history_cachable()
    if use_cache:
      cached = self.posterior_cache.get(self.consumed)
      if cached is not None:
        logging.debug("Loaded NMT posterior from cache for %s",
                      self.consumed)
        posterior, dec_state = cached
        # Without this, dec_state would stay at the pre-decode state and a
        # following consume() of an uncached word would use (and cache) a
        # stale decoder state.
        self.dec_state = copy.deepcopy(dec_state)
        return posterior

    output, self.dec_state = self.single_step_decoding_graph.decode(self.session, self.enc_out,
                                               self.dec_state, self.decoder_input, self.bucket_id,
                                               self.config['use_src_mask'], self.word_count,
                                               self.config['use_bow_mask'])
    if use_cache:
      self.posterior_cache.add(self.consumed,
                               (output[0], copy.deepcopy(self.dec_state)))

    return output[0]

  def get_unk_probability(self, posterior):
    """Returns the UNK score from the last ``predict_next`` posterior.

    The posterior is indexed by the NMT vocabulary, so the NMT UNK index
    ``tf_data_utils.UNK_ID`` is used. This is the ID ``consume`` maps
    OOV words to, and it is presumably changed by ``no_pad_symbol()``.
    A one-entry posterior (the EOS-only dict) has no UNK score.
    """
    return posterior[tf_data_utils.UNK_ID] if len(posterior) > 1 else float("-inf")

  def consume(self, word):
    """Advances the decoder by ``word``.

    Words outside the NMT target vocabulary are mapped to UNK before
    they are recorded in the history. If the history is cachable, the
    resulting decoder state is restored from or stored in the cache.
    """
    if word >= self.config['trg_vocab_size']:
      word = tf_data_utils.UNK_ID  # history is kept according to nmt vocab
    self.consumed.append(word)

    use_cache = self.is_history_cachable()
    if use_cache:
      s = self.states_cache.get(self.consumed)
      if s is not None:
        logging.debug("Loaded NMT decoder states from cache for %s",
                      self.consumed)
        states = copy.deepcopy(s)
        self.decoder_input = states[0]
        self.dec_state = states[1]
        self.word_count = states[2]
        return

    self.decoder_input = [word]
    self.word_count = self.word_count + 1

    if use_cache:
      states = (self.decoder_input, self.dec_state, self.word_count)
      self.states_cache.add(self.consumed, copy.deepcopy(states))

  def get_state(self):
    """Returns (decoder_input, dec_state, word_count, consumed)."""
    return (self.decoder_input, self.dec_state, self.word_count, self.consumed)

  def set_state(self, state):
    """Restores a state returned by ``get_state``."""
    self.decoder_input, self.dec_state, self.word_count, self.consumed = state

  def is_equal(self, state1, state2):
    """Returns true if the history is the same """
    _, _, _, consumed1 = state1
    _, _, _, consumed2 = state2
    return consumed1 == consumed2
コード例 #7
0
class Word2charPredictor(UnboundedVocabularyPredictor):
    """This predictor wraps word level predictors when SGNMT is running
    on the character level. The mapping between word ID and character 
    ID sequence is loaded from the file system. All characters which
    do not appear in that mapping are treated as word boundary
    markers. The wrapper blocks consume and predict_next calls until a
    word boundary marker is consumed, and updates the slave predictor
    according to the word between the last two word boundaries. The 
    mapping is done only on the target side, and the source sentences
    are passed through as they are. To use alternative tokenization on
    the source side, see the altsrc predictor wrapper. The word2char
    wrapper is always an ``UnboundedVocabularyPredictor``.
    """
    
    def __init__(self, map_path, slave_predictor):
        """Creates a new word2char wrapper predictor. The map_path 
        file has to be a plain text file, each line containing the 
        mapping from a word index to the character index sequence
        (format: word char1 char2... charn).
        
        Args:
            map_path (string): Path to the mapping file
            slave_predictor (Predictor): Instance of the predictor with
                                         a different wmap than SGNMT
        """
        super(Word2charPredictor, self).__init__()
        self.slave_predictor = slave_predictor
        self.words = SimpleTrie()  # char ID sequence -> word ID
        self.word_chars = {}  # set of chars that occur in any word
        with open(map_path) as f:
            for line in f:
                l = [int(x) for x in line.strip().split()]
                word = l[0]
                chars = l[1:]
                self.words.add(chars, word)
                for c in chars:
                    self.word_chars[c] = True   
        # Bounded slaves are queried with predict_next() once per word and
        # the full posterior is kept; unbounded slaves are queried per word.
        if isinstance(slave_predictor, UnboundedVocabularyPredictor): 
            self._get_stub_prob = self._get_stub_prob_unbounded
            self._start_new_word = self._start_new_word_unbounded
        else:
            self._get_stub_prob = self._get_stub_prob_bounded
            self._start_new_word = self._start_new_word_bounded             
    
    def initialize(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified 
        """
        self.slave_predictor.initialize(src_sentence)
        self._start_new_word()
    
    def initialize_heuristic(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified 
        """
        self.slave_predictor.initialize_heuristic(src_sentence)
    
    def _update_slave_vars(self, posterior):
        """Cache the slave's UNK, GO, and EOS scores from ``posterior``."""
        self.slave_unk = self.slave_predictor.get_unk_probability(posterior)
        self.slave_go = common_get(posterior, utils.GO_ID, self.slave_unk)
        self.slave_eos = common_get(posterior, utils.EOS_ID, self.slave_unk)
        
    def _start_new_word_unbounded(self):
        """start_new_word implementation for unbounded vocabulary slave
        predictors. Needs to set slave_go, slave_eos, and slave_unk
        """
        self.word_stub = []
        posterior = self.slave_predictor.predict_next([utils.UNK_ID,
                                                       utils.GO_ID,
                                                       utils.EOS_ID])
        self._update_slave_vars(posterior)
    
    def _start_new_word_bounded(self):
        """start_new_word implementation for bounded vocabulary slave
        predictors. Needs to set slave_go, slave_eos, slave_unk, and
        slave_posterior
        """
        self.word_stub = []
        self.slave_posterior = self.slave_predictor.predict_next()
        self._update_slave_vars(self.slave_posterior)
    
    def _get_stub_prob_unbounded(self):
        """get_stub_prob implementation for unbounded vocabulary slave
        predictors. Unmapped stubs get the slave's UNK score.
        """
        word = self.words.get(self.word_stub)
        if word:
            posterior = self.slave_predictor.predict_next([word])
            return common_get(posterior, word, self.slave_unk)
        return self.slave_unk
    
    def _get_stub_prob_bounded(self):
        """get_stub_prob implementation for bounded vocabulary slave
        predictors. Unmapped stubs are looked up as UNK.
        """
        word = self.words.get(self.word_stub)
        return common_get(self.slave_posterior,
                          word if word else utils.UNK_ID,
                          self.slave_unk)
    
    def predict_next(self, trgt_words):
        """Score the characters in ``trgt_words``.

        Characters that occur in the word map get 0.0. Every other
        character is a word boundary marker and gets the slave score of
        the current word stub (0.0 if the stub is empty), computed at
        most once. GO and EOS additionally get the slave's GO/EOS score
        taken when the current word was started.
        """
        posterior = {}
        stub_prob = False  # sentinel: not computed yet
        for ch in trgt_words:
            if ch in self.word_chars:
                posterior[ch] = 0.0
            else: # Word boundary marker
                if stub_prob is False:
                    stub_prob = self._get_stub_prob() if self.word_stub else 0.0
                posterior[ch] = stub_prob
        if utils.GO_ID in posterior:
            posterior[utils.GO_ID] += self.slave_go
        if utils.EOS_ID in posterior:
            posterior[utils.EOS_ID] += self.slave_eos
        return posterior
        
    def get_unk_probability(self, posterior):
        """This is about the unknown character, not word. Since the word
        level slave predictor has no notion of the unknown character, 
        we return NEG_INF unconditionally.
        """
        return NEG_INF
    
    def consume(self, word):
        """If ``word`` is a word boundary marker and ``word_stub`` is not
        empty, let the slave predictor consume the word for
        ``word_stub`` (UNK if unmapped) and start a new, empty stub.
        Otherwise, if ``word`` is a mapped character, extend
        ``word_stub`` by it. Boundary markers after an empty stub are
        ignored.
        """
        if word in self.word_chars:
            self.word_stub.append(word)
        elif self.word_stub:
            word = self.words.get(self.word_stub)
            self.slave_predictor.consume(word if word else utils.UNK_ID)
            self._start_new_word()
    
    def get_state(self):
        """Pass through to slave predictor """
        return self.word_stub, self.slave_predictor.get_state()
    
    def set_state(self, state):
        """Pass through to slave predictor """
        self.word_stub, slave_state = state
        self.slave_predictor.set_state(slave_state)

    def reset(self):
        """Pass through to slave predictor """
        self.slave_predictor.reset()

    def estimate_future_cost(self, hypo):
        """Not supported """
        logging.warning("Cannot use future cost estimates of predictors "
                        "wrapped by word2char")
        return 0.0

    def set_current_sen_id(self, cur_sen_id):
        """We need to override this method to propagate current\_
        sentence_id to the slave predictor
        """
        super(Word2charPredictor, self).set_current_sen_id(cur_sen_id)
        self.slave_predictor.set_current_sen_id(cur_sen_id)
    
    def is_equal(self, state1, state2):
        """Pass through to slave predictor """
        stub1, slave_state1 = state1
        stub2, slave_state2 = state2
        return (stub1 == stub2 
                and self.slave_predictor.is_equal(slave_state1, slave_state2))
コード例 #8
0
ファイル: tokenization.py プロジェクト: ucam-smt/sgnmt
class Word2charPredictor(UnboundedVocabularyPredictor):
    """This predictor wraps word level predictors when SGNMT is running
    on the character level. The mapping between word ID and character 
    ID sequence is loaded from the file system. All characters which
    do not appear in that mapping are treated as word boundary
    markers. The wrapper blocks consume and predict_next calls until a
    word boundary marker is consumed, and updates the slave predictor
    according to the word between the last two word boundaries. The 
    mapping is done only on the target side, and the source sentences
    are passed through as they are. To use alternative tokenization on
    the source side, see the altsrc predictor wrapper. The word2char
    wrapper is always an ``UnboundedVocabularyPredictor``.
    """
    
    def __init__(self, map_path, slave_predictor):
        """Creates a new word2char wrapper predictor. The map_path 
        file has to be a plain text file, each line containing the 
        mapping from a word index to the character index sequence
        (format: word char1 char2... charn).
        
        Args:
            map_path (string): Path to the mapping file
            slave_predictor (Predictor): Instance of the predictor with
                                         a different wmap than SGNMT
        """
        super(Word2charPredictor, self).__init__()
        self.slave_predictor = slave_predictor
        self.words = SimpleTrie()  # char ID sequence -> word ID
        self.word_chars = {}  # set of chars that occur in any word
        with open(map_path) as f:
            for line in f:
                l = [int(x) for x in line.strip().split()]
                word = l[0]
                chars = l[1:]
                self.words.add(chars, word)
                for c in chars:
                    self.word_chars[c] = True   
        # Bounded slaves are queried with predict_next() once per word and
        # the full posterior is kept; unbounded slaves are queried per word.
        if isinstance(slave_predictor, UnboundedVocabularyPredictor): 
            self._get_stub_prob = self._get_stub_prob_unbounded
            self._start_new_word = self._start_new_word_unbounded
        else:
            self._get_stub_prob = self._get_stub_prob_bounded
            self._start_new_word = self._start_new_word_bounded             
    
    def initialize(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified 
        """
        self.slave_predictor.initialize(src_sentence)
        self._start_new_word()
    
    def initialize_heuristic(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified 
        """
        self.slave_predictor.initialize_heuristic(src_sentence)
    
    def _update_slave_vars(self, posterior):
        """Cache the slave's UNK, GO, and EOS scores from ``posterior``."""
        self.slave_unk = self.slave_predictor.get_unk_probability(posterior)
        self.slave_go = common_get(posterior, utils.GO_ID, self.slave_unk)
        self.slave_eos = common_get(posterior, utils.EOS_ID, self.slave_unk)
        
    def _start_new_word_unbounded(self):
        """start_new_word implementation for unbounded vocabulary slave
        predictors. Needs to set slave_go, slave_eos, and slave_unk
        """
        self.word_stub = []
        posterior = self.slave_predictor.predict_next([utils.UNK_ID,
                                                       utils.GO_ID,
                                                       utils.EOS_ID])
        self._update_slave_vars(posterior)
    
    def _start_new_word_bounded(self):
        """start_new_word implementation for bounded vocabulary slave
        predictors. Needs to set slave_go, slave_eos, slave_unk, and
        slave_posterior
        """
        self.word_stub = []
        self.slave_posterior = self.slave_predictor.predict_next()
        self._update_slave_vars(self.slave_posterior)
    
    def _get_stub_prob_unbounded(self):
        """get_stub_prob implementation for unbounded vocabulary slave
        predictors. Unmapped stubs get the slave's UNK score.
        """
        word = self.words.get(self.word_stub)
        if word:
            posterior = self.slave_predictor.predict_next([word])
            return common_get(posterior, word, self.slave_unk)
        return self.slave_unk
    
    def _get_stub_prob_bounded(self):
        """get_stub_prob implementation for bounded vocabulary slave
        predictors. Unmapped stubs are looked up as UNK.
        """
        word = self.words.get(self.word_stub)
        return common_get(self.slave_posterior,
                          word if word else utils.UNK_ID,
                          self.slave_unk)
    
    def predict_next(self, trgt_words):
        """Score the characters in ``trgt_words``.

        Characters that occur in the word map get 0.0. Every other
        character is a word boundary marker and gets the slave score of
        the current word stub (0.0 if the stub is empty), computed at
        most once. GO and EOS additionally get the slave's GO/EOS score
        taken when the current word was started.
        """
        posterior = {}
        stub_prob = False  # sentinel: not computed yet
        for ch in trgt_words:
            if ch in self.word_chars:
                posterior[ch] = 0.0
            else: # Word boundary marker
                if stub_prob is False:
                    stub_prob = self._get_stub_prob() if self.word_stub else 0.0
                posterior[ch] = stub_prob
        if utils.GO_ID in posterior:
            posterior[utils.GO_ID] += self.slave_go
        if utils.EOS_ID in posterior:
            posterior[utils.EOS_ID] += self.slave_eos
        return posterior
        
    def get_unk_probability(self, posterior):
        """This is about the unknown character, not word. Since the word
        level slave predictor has no notion of the unknown character, 
        we return NEG_INF unconditionally.
        """
        return NEG_INF
    
    def consume(self, word):
        """If ``word`` is a word boundary marker and ``word_stub`` is not
        empty, let the slave predictor consume the word for
        ``word_stub`` (UNK if unmapped) and start a new, empty stub.
        Otherwise, if ``word`` is a mapped character, extend
        ``word_stub`` by it. Boundary markers after an empty stub are
        ignored.
        """
        if word in self.word_chars:
            self.word_stub.append(word)
        elif self.word_stub:
            word = self.words.get(self.word_stub)
            self.slave_predictor.consume(word if word else utils.UNK_ID)
            self._start_new_word()
    
    def get_state(self):
        """Pass through to slave predictor """
        return self.word_stub, self.slave_predictor.get_state()
    
    def set_state(self, state):
        """Pass through to slave predictor """
        self.word_stub, slave_state = state
        self.slave_predictor.set_state(slave_state)

    def reset(self):
        """Pass through to slave predictor. Without this override the
        wrapped predictor would never be reset.
        """
        self.slave_predictor.reset()

    def estimate_future_cost(self, hypo):
        """Not supported """
        logging.warning("Cannot use future cost estimates of predictors "
                        "wrapped by word2char")
        return 0.0

    def set_current_sen_id(self, cur_sen_id):
        """We need to override this method to propagate current\_
        sentence_id to the slave predictor
        """
        super(Word2charPredictor, self).set_current_sen_id(cur_sen_id)
        self.slave_predictor.set_current_sen_id(cur_sen_id)
    
    def is_equal(self, state1, state2):
        """Pass through to slave predictor """
        stub1, slave_state1 = state1
        stub2, slave_state2 = state2
        return (stub1 == stub2 
                and self.slave_predictor.is_equal(slave_state1, slave_state2))
コード例 #9
0
class BagOfWordsPredictor(Predictor):
    """This predictor is similar to the forced predictor, but it does
    not enforce the word order in the reference. Therefore, it assigns
    score 0.0 (probability 1) to all hypotheses which consist of words
    from the reference in any order, and -inf to all other hypos.
    """
    def __init__(self,
                 trg_test_file,
                 accept_subsets=False,
                 accept_duplicates=False,
                 heuristic_scores_file="",
                 collect_stats_strategy='best',
                 heuristic_add_consumed=False,
                 heuristic_add_remaining=True,
                 diversity_heuristic_factor=-1.0,
                 equivalence_vocab=-1):
        """Creates a new bag-of-words predictor.
        
        Args:
            trg_test_file (string): Path to the plain text file with 
                                    the target sentences. Must have the
                                    same number of lines as the number
                                    of source sentences to decode. The 
                                    word order in the target sentences
                                    is not relevant for this predictor.
            accept_subsets (bool): If true, this predictor permits
                                   EOS even if the bag is not fully
                                   consumed yet
            accept_duplicates (bool): If true, counts are not updated
                                      when a word is consumed. This
                                      means that we allow a word in a
                                      bag to appear multiple times. The
                                      bag then never becomes empty, so
                                      this should be combined with
                                      ``accept_subsets``
            heuristic_scores_file (string): Path to the unigram scores 
                                            which are used if this 
                                            predictor estimates future
                                            costs
            collect_stats_strategy (string): best, full, or all. Defines 
                                             how unigram estimates are 
                                             collected for heuristic 
            heuristic_add_consumed (bool): Set to true to add the 
                                           difference between actual
                                           partial score and unigram
                                           estimates of consumed words
                                           to the predictor heuristic
            heuristic_add_remaining (bool): Set to true to add the sum
                                            of unigram scores of words
                                            remaining in the bag to the
                                            predictor heuristic
            diversity_heuristic_factor (float): Factor for diversity
                                                heuristic which 
                                                penalizes hypotheses
                                                with the same bag as
                                                full hypos. Enabled
                                                only if positive
            equivalence_vocab (int): If positive, predictor states are
                                     considered equal if the
                                     remaining words within that vocab
                                     and OOVs regarding this vocab are
                                     the same. Only relevant when using
                                     hypothesis recombination
        """
        super(BagOfWordsPredictor, self).__init__()
        with open(trg_test_file) as f:
            self.lines = f.read().splitlines()
        if heuristic_scores_file:
            self.estimates = FileUnigramTable(heuristic_scores_file)
        elif collect_stats_strategy == 'best':
            self.estimates = BestStatsUnigramTable()
        elif collect_stats_strategy == 'full':
            self.estimates = FullStatsUnigramTable()
        elif collect_stats_strategy == 'all':
            self.estimates = AllStatsUnigramTable()
        else:
            # NOTE(review): self.estimates stays unset here; heuristic
            # methods (notify, initialize_heuristic, ...) will then fail.
            logging.error("Unknown statistics collection strategy")
        self.accept_subsets = accept_subsets
        self.accept_duplicates = accept_duplicates
        self.heuristic_add_consumed = heuristic_add_consumed
        self.heuristic_add_remaining = heuristic_add_remaining
        self.equivalence_vocab = equivalence_vocab
        if accept_duplicates and not accept_subsets:
            logging.error("You enabled bow_accept_duplicates but not bow_"
                          "accept_subsets. Therefore, the bow predictor will "
                          "never accept end-of-sentence and could cause "
                          "an infinite loop in the search strategy.")
        self.diversity_heuristic_factor = diversity_heuristic_factor
        self.diverse_heuristic = (diversity_heuristic_factor > 0.0)

    def get_unk_probability(self, posterior):
        """Returns negative infinity unconditionally: Words which are
        not in the target sentence have assigned probability 0 by
        this predictor.
        """
        return NEG_INF

    def predict_next(self):
        """If the bag is empty, the only allowed symbol is EOS. 
        Otherwise, return the words in the bag (plus EOS if
        ``accept_subsets`` is set), each with score 0.0.
        """
        if not self.bag:  # Empty bag
            return {utils.EOS_ID: 0.0}
        ret = {w: 0.0 for w in self.bag.iterkeys()}
        if self.accept_subsets:
            ret[utils.EOS_ID] = 0.0
        return ret

    def initialize(self, src_sentence):
        """Creates a new bag for the current target sentence, mapping
        each word ID of line ``current_sen_id`` to its count.
        
        Args:
            src_sentence (list):  Not used
        """
        self.best_hypo_score = NEG_INF
        self.bag = {}
        for w in self.lines[self.current_sen_id].strip().split():
            int_w = int(w)
            self.bag[int_w] = self.bag.get(int_w, 0) + 1
        # Unmodified copy used by the future cost heuristic.
        self.full_bag = dict(self.bag)

    def consume(self, word):
        """Updates the bag after ``word`` has been consumed.

        Consuming EOS empties the bag. Consuming a word which is not in
        the bag logs a warning and leaves the bag unchanged. Otherwise the
        count of ``word`` is decremented and the entry is removed when it
        drops to zero -- unless ``accept_duplicates`` is set, in which
        case counts are not updated so the word may be used again.
        
        Args:
            word (int): Next word to consume
        """
        if word == utils.EOS_ID:
            self.bag = {}
            return
        if word not in self.bag:
            logging.warning("Consuming word which is not in bag-of-words!")
            return
        if self.accept_duplicates:
            # Previously the word was popped entirely in this mode, which
            # contradicted the documented behaviour (counts unchanged).
            return
        cnt = self.bag.pop(word)
        if cnt > 1:
            self.bag[word] = cnt - 1

    def get_state(self):
        """State of this predictor is the current bag """
        return self.bag

    def set_state(self, state):
        """State of this predictor is the current bag """
        self.bag = state

    def reset(self):
        """Empty method. """
        pass

    def initialize_heuristic(self, src_sentence):
        """Calls ``reset`` of the used unigram table with estimates
        ``self.estimates`` to clear all statistics from the previous
        sentence, and clears the explored bags if the diversity
        heuristic is enabled.
        
        Args:
            src_sentence (list): Not used
        """
        self.estimates.reset()
        if self.diverse_heuristic:
            self.explored_bags = SimpleTrie()

    def notify(self, message, message_type=MESSAGE_TYPE_DEFAULT):
        """This gets called if this predictor observes the decoder. It
        updates unigram heuristic estimates via passing through this
        message to the unigram table ``self.estimates``. Full hypotheses
        also update the explored bags of the diversity heuristic.
        """
        self.estimates.notify(message, message_type)
        if self.diverse_heuristic and message_type == MESSAGE_TYPE_FULL_HYPO:
            self._update_explored_bags(message)

    def _update_explored_bags(self, hypo):
        """This is called if diversity heuristic is enabled. For every
        proper prefix of ``hypo.trgt_sentence`` (lengths 0 to len-1), the
        counter of its sorted word list in ``self.explored_bags`` is
        incremented by 1.0.
        """
        sen = hypo.trgt_sentence
        for l in xrange(len(sen)):
            key = sen[:l]
            key.sort()
            cnt = self.explored_bags.get(key)
            if not cnt:
                cnt = 0.0
            self.explored_bags.add(key, cnt + 1.0)

    def estimate_future_cost(self, hypo):
        """The bow predictor comes with its own heuristic function. The
        estimate is the sum of up to three terms:

        * ``heuristic_add_remaining``: minus the summed unigram estimates
          of the words (and EOS) not yet covered by ``hypo``.
        * diversity heuristic: ``diversity_heuristic_factor`` times the
          explored-bag counter of the sorted words of ``hypo``.
        * ``heuristic_add_consumed``: minus the difference between the
          hypothesis score and the summed unigram estimates of its words.
        """
        acc = 0.0
        if self.heuristic_add_remaining:
            remaining = dict(self.full_bag)
            remaining[utils.EOS_ID] = 1
            # NOTE(review): raises KeyError for words outside the reference
            # and counts can go negative when accept_duplicates is set.
            for w in hypo.trgt_sentence:
                remaining[w] -= 1
            acc -= sum([
                cnt * self.estimates.estimate(w)
                for w, cnt in remaining.iteritems()
            ])
        if self.diverse_heuristic:
            key = list(hypo.trgt_sentence)
            key.sort()
            cnt = self.explored_bags.get(key)
            if cnt:
                acc += cnt * self.diversity_heuristic_factor
        if self.heuristic_add_consumed:
            acc -= hypo.score - sum([
                self.estimates.estimate(w, -1000.0) for w in hypo.trgt_sentence
            ])
        return acc

    def _get_unk_bag(self, org_bag):
        """Returns ``org_bag`` with all words >= ``equivalence_vocab``
        merged into ``utils.UNK_ID`` (counts summed). Returns ``org_bag``
        itself if ``equivalence_vocab`` is not positive.
        """
        if self.equivalence_vocab <= 0:
            return org_bag
        unk_bag = {}
        for word, cnt in org_bag.iteritems():
            idx = word if word < self.equivalence_vocab else utils.UNK_ID
            unk_bag[idx] = unk_bag.get(idx, 0) + cnt
        return unk_bag

    def is_equal(self, state1, state2):
        """Returns true if the bags are the same (after OOV merging when
        ``equivalence_vocab`` is positive).
        """
        return self._get_unk_bag(state1) == self._get_unk_bag(state2)
コード例 #10
0
ファイル: length.py プロジェクト: ucam-smt/sgnmt
class NgramCountPredictor(Predictor):
    """This predictor counts the number of n-grams in hypotheses. n-gram
    posteriors are loaded from a file. The predictor score is the sum of
    all n-gram posteriors in a hypothesis. """
    
    def __init__(self, path, order=0, discount_factor=-1.0):
        """Creates a new ngram count predictor instance.
        
        Args:
            path (string): Path to the n-gram posteriors. File format:
                           <ngram> : <score> (one ngram per line). Use
                           placeholder %d for sentence id.
            order (int): If positive, count n-grams of the specified
                         order. Otherwise, count all n-grams
            discount_factor (float): If non-negative, discount n-gram
                                     posteriors by this factor each time 
                                     they are consumed 
        """
        super(NgramCountPredictor, self).__init__()
        self.path = path 
        self.order = order
        self.discount_factor = discount_factor
        
    def get_unk_probability(self, posterior):
        """Always return 0.0 """
        return 0.0
    
    def predict_next(self):
        """Composes the posterior vector by collecting all ngrams which
        are consistent with the current history.

        For every suffix of ``cur_history`` (including the empty one)
        which is a context in ``self.ngrams``, the scores of its
        continuations are added up. With discounting enabled, each score
        is multiplied by the discount factor accumulated for that
        context/word pair (1.0 if none).
        """
        posterior = {}
        for i in reversed(range(len(self.cur_history)+1)):
            scores = self.ngrams.get(self.cur_history[i:])
            if scores:
                factors = False
                if self.discount_factor >= 0.0:
                    factors = self.discounts.get(self.cur_history[i:])
                if not factors:
                    for w,score in scores.iteritems():
                        posterior[w] = posterior.get(w, 0.0) + score
                else:
                    for w,score in scores.iteritems():
                        posterior[w] = posterior.get(w, 0.0) +  \
                                       factors.get(w, 1.0) * score
        return posterior
    
    def _load_posteriors(self, path):
        """Sets up self.max_history_len and self.ngrams.

        ``self.ngrams`` maps each history (all but the last word of an
        n-gram) to a dict ``{last_word: score}``. If ``self.order`` is
        positive, n-grams of other orders are skipped. n-grams ending in
        ``utils.GO_ID`` are skipped as well.
        """
        self.max_history_len = 0
        self.ngrams = SimpleTrie()
        logging.debug("Loading n-gram scores from %s..." % path)
        with open(path) as f:
            for line in f:
                ngram,score = line.split(':')
                words = [int(w) for w in ngram.strip().split()]
                if self.order > 0 and len(words) != self.order:
                    continue
                hist = words[:-1]
                last_word = words[-1]
                if last_word == utils.GO_ID:
                    continue
                self.max_history_len = max(self.max_history_len, len(hist))
                p = self.ngrams.get(hist)
                if p:
                    p[last_word] = float(score.strip())
                else:
                    self.ngrams.add(hist, {last_word: float(score.strip())})
    
    def initialize(self, src_sentence):
        """Loads n-gram posteriors and resets history and discounts.
        
        Args:
            src_sentence (list): not used
        """
        # Sentence IDs passed to get_path are 1-based.
        self._load_posteriors(utils.get_path(self.path, self.current_sen_id+1))
        self.cur_history = [utils.GO_ID]
        self.discounts = SimpleTrie()
    
    def consume(self, word):
        """Adds ``word`` to the current history and updates discounts.

        If discounting is enabled, the discount factor of ``word`` is
        applied for every suffix of the history preceding ``word``
        (including the empty history). This is done before the history is
        shortened so that the longest context (up to ``max_history_len``
        words) is discounted as well. Afterwards the history is shortened
        to its last ``max_history_len`` entries.
        
        Args:
            word (int): Word to add to the history.
        """
        self.cur_history.append(word)
        if self.discount_factor >= 0.0:
            for i in range(len(self.cur_history)):
                key = self.cur_history[i:-1]
                factors = self.discounts.get(key)
                if not factors:
                    factors = {word: self.discount_factor}
                else:
                    factors[word] = factors.get(word, 1.0)*self.discount_factor
                self.discounts.add(key, factors)
        excess = len(self.cur_history) - self.max_history_len
        if excess > 0:
            # Slice by start index: ``cur_history[-0:]`` would keep the whole
            # history when max_history_len is 0 (unigrams only).
            self.cur_history = self.cur_history[excess:]
    
    def get_state(self):
        """Current history and discounts are the predictor state """
        return self.cur_history,self.discounts
    
    def set_state(self, state):
        """Current history and discounts are the predictor state """
        self.cur_history,self.discounts = state

    def is_equal(self, state1, state2):
        """Returns true if both histories lead to the same posteriors,
        i.e. they agree on every suffix which is a context in
        ``self.ngrams``. Hypothesis recombination is
        not supported if discounting is enabled.
        """
        if self.discount_factor >= 0.0:
            return False
        hist1 = state1[0]
        hist2 = state2[0]
        if hist1 == hist2: # Return true if histories match
            return True
        if len(hist1) > len(hist2):
            hist_long = hist1
            hist_short = hist2
        else:
            hist_long = hist2
            hist_short = hist1
        min_len = len(hist_short)
        for n in xrange(1, min_len+1): # Look up non matching in self.ngrams
            key1 = hist1[-n:]
            key2 = hist2[-n:]
            if key1 != key2:
                if self.ngrams.get(key1) or self.ngrams.get(key2):
                    return False
        # Suffixes only the longer history has must not be known contexts.
        for n in xrange(min_len+1, len(hist_long)+1):
            if self.ngrams.get(hist_long[-n:]):
                return False
        return True
コード例 #11
0
ファイル: tf_nmt.py プロジェクト: ml-lab/sgnmt
class TensorFlowNMTPredictor(Predictor):
    """Neural MT predictor backed by a TensorFlow seq2seq model.

    The source sentence is encoded once in ``initialize``; every call of
    ``predict_next`` runs one decoder step. If enabled, posteriors and
    decoder states are cached for histories which contain UNK (see
    ``is_history_cachable``).
    """
    def __init__(self, enable_cache, config, session):
        """Loads the TensorFlow model and sets up the decoding state.

        Args:
            enable_cache (bool): Enable caching of posteriors and decoder
                states for histories containing UNK.
            config (dict): Model configuration. Missing encoder-specific
                entries are added to this dict in place.
            session: TensorFlow session used for all model calls.
        """
        super(TensorFlowNMTPredictor, self).__init__()
        self.config = config
        self.session = session

        # Add missing entries in config
        if self.config['encoder'] == "bow":
            self.config['init_backward'] = False
            self.config['use_seqlen'] = False
        else:
            self.config['bow_init_const'] = False
            self.config['use_bow_mask'] = False

        # Load tensorflow model
        self.model, self.training_graph, self.encoding_graph, \
          self.single_step_decoding_graph, self.buckets = tf_model_utils.load_model(self.session, config)
        self.model.batch_size = 1  # We decode one sentence at a time.

        self.enc_out = {}
        self.decoder_input = [tf_data_utils.GO_ID]
        self.dec_state = {}
        self.bucket_id = -1
        self.num_heads = 1
        self.word_count = 0

        if config['no_pad_symbol']:
            # This needs to be set in tensorflow data_utils for correct source masks
            tf_data_utils.no_pad_symbol()
            logging.info("UNK_ID=%d" % tf_data_utils.UNK_ID)
            logging.info("PAD_ID=%d" % tf_data_utils.PAD_ID)

        self.enable_cache = enable_cache
        if self.enable_cache:
            logging.info("Cache enabled..")

    def initialize(self, src_sentence):
        """Resets the decoder, encodes ``src_sentence`` and initializes the
        decoder state from the last encoder state.

        Source words >= ``config['src_vocab_size']`` are mapped to
        ``tf_data_utils.UNK_ID``. The smallest feasible bucket index is
        used; if none fits, a new bucket is created and added to the model.

        Args:
            src_sentence (list): Source word IDs, without <s> and </s>.
        """
        # src_sentence is list of integers, without <s> and </s>
        self.reset()
        self.posterior_cache = SimpleTrie()
        self.states_cache = SimpleTrie()

        src_sentence = [
            w if w < self.config['src_vocab_size'] else tf_data_utils.UNK_ID
            for w in src_sentence
        ]
        if self.config['add_src_eos']:
            src_sentence.append(tf_data_utils.EOS_ID)

        feasible_buckets = [
            b for b in xrange(len(self.buckets))
            if self.buckets[b][0] >= len(src_sentence)
        ]
        if not feasible_buckets:
            # Get a new bucket
            bucket = tf_model_utils.make_bucket(len(src_sentence))
            logging.info("Add new bucket={} and update model".format(bucket))
            self.buckets.append(bucket)
            self.model.update_buckets(self.buckets)
            self.bucket_id = len(self.buckets) - 1
        else:
            # NOTE(review): assumes buckets are ordered by size -- confirm.
            self.bucket_id = min(feasible_buckets)

        encoder_inputs, _, _, sequence_length, src_mask, bow_mask = self.training_graph.get_batch(
            {self.bucket_id: [(src_sentence, [])]}, self.bucket_id,
            self.config['encoder'])
        logging.info("bucket={}".format(self.buckets[self.bucket_id]))

        last_enc_state, self.enc_out = self.encoding_graph.encode(
            self.session, encoder_inputs, self.bucket_id, sequence_length)

        # Initialize decoder state with last encoder state
        self.dec_state["dec_state"] = last_enc_state
        for a in xrange(self.num_heads):
            self.dec_state["dec_attns_%d" % a] = np.zeros(
                (1, self.enc_out['enc_v_0'].size), dtype=np.float32)

        if self.config['use_src_mask']:
            self.dec_state["src_mask"] = src_mask
            self.src_mask_orig = src_mask.copy()

        if self.config['use_bow_mask']:
            self.dec_state["bow_mask"] = bow_mask
            self.bow_mask_orig = bow_mask.copy()

    def is_history_cachable(self):
        """Returns true if cache is enabled and history contains UNK """
        if not self.enable_cache:
            return False
        for w in self.consumed:
            if w == tf_data_utils.UNK_ID:
                return True
        return False

    def predict_next(self):
        """Runs one decoder step and returns its output.

        Returns ``{EOS: 0}`` once EOS is the decoder input. Otherwise the
        first entry of the decoder output is returned (presumably a score
        vector over the target vocabulary -- confirm); it may come from
        ``posterior_cache`` if the history is cachable.
        """
        # should return list, numpy array, or dictionary
        if self.decoder_input[0] == tf_data_utils.EOS_ID:  # Predict EOS
            return {tf_data_utils.EOS_ID: 0}

        use_cache = self.is_history_cachable()
        if use_cache:
            posterior = self.posterior_cache.get(self.consumed)
            if not posterior is None:
                logging.debug("Loaded NMT posterior from cache for %s" %
                              self.consumed)
                return posterior

        output, self.dec_state = self.single_step_decoding_graph.decode(
            self.session, self.enc_out, self.dec_state, self.decoder_input,
            self.bucket_id, self.config['use_src_mask'], self.word_count,
            self.config['use_bow_mask'])
        if use_cache:
            self.posterior_cache.add(self.consumed, output[0])

        return output[0]

    def get_unk_probability(self, posterior):
        """Returns the model's UNK score from ``posterior``.

        ``posterior`` is the value returned by the last ``predict_next``
        call. A single-entry posterior (the forced EOS case) yields -inf.

        The UNK index is ``tf_data_utils.UNK_ID`` -- the ID that
        ``initialize`` and ``consume`` map out-of-vocabulary words to and
        that ``no_pad_symbol()`` may change -- rather than
        ``utils.UNK_ID``, which need not coincide with it.
        """
        if len(posterior) > 1:
            return posterior[tf_data_utils.UNK_ID]
        return float("-inf")

    def consume(self, word):
        """Sets ``word`` as next decoder input and extends the history.

        Words >= ``config['trg_vocab_size']`` are mapped to
        ``tf_data_utils.UNK_ID``. If the history is cachable and decoder
        states for it are cached, a deep copy of them is restored instead.
        """
        if word >= self.config['trg_vocab_size']:
            word = tf_data_utils.UNK_ID  # history is kept according to nmt vocab
        logging.debug("Consume word={}".format(word))
        self.consumed.append(word)

        use_cache = self.is_history_cachable()
        if use_cache:
            s = self.states_cache.get(self.consumed)
            if not s is None:
                logging.debug("Loaded NMT decoder states from cache for %s" %
                              self.consumed)
                states = copy.deepcopy(s)
                self.decoder_input = states[0]
                self.dec_state = states[1]
                self.word_count = states[2]
                return

        self.decoder_input = [word]
        self.word_count = self.word_count + 1

        if use_cache:
            states = (self.decoder_input, self.dec_state, self.word_count)
            self.states_cache.add(self.consumed, copy.deepcopy(states))

    def get_state(self):
        """Returns ``(decoder_input, dec_state, word_count, consumed)``."""
        return (self.decoder_input, self.dec_state, self.word_count,
                self.consumed)

    def set_state(self, state):
        """Restores a state produced by ``get_state``."""
        self.decoder_input, self.dec_state, self.word_count, self.consumed = state

    def reset(self):
        """Clears encoder output, decoder state and consumed history."""
        self.enc_out = {}
        self.decoder_input = [tf_data_utils.GO_ID]
        self.dec_state = {}
        self.word_count = 0
        self.consumed = []

    def is_equal(self, state1, state2):
        """Returns true if the history is the same """
        _, _, _, consumed1 = state1
        _, _, _, consumed2 = state2
        return consumed1 == consumed2
コード例 #12
0
ファイル: tf_t2t.py プロジェクト: ucam-smt/sgnmt
class EditT2TPredictor(_BaseTensor2TensorPredictor):
    """This predictor can be used for T2T models conditioning on the
    full target sentence. The predictor state is a full target sentence.
    The state can be changed by insertions, substitutions, and deletions
    of single tokens, whereas each operation is encoded as SGNMT token
    in the following way:

      1xxxyyyyy: Insert the token yyyyy at position xxx.
      2xxxyyyyy: Replace the xxx-th word with the token yyyyy.
      3xxx00000: Delete the xxx-th token.
    """

    INS_OFFSET = 100000000
    SUB_OFFSET = 200000000
    DEL_OFFSET = 300000000

    POS_FACTOR = 100000
    MAX_SEQ_LEN = 999

    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 model_name,
                 problem_name,
                 hparams_set_name,
                 trg_test_file,
                 beam_size,
                 t2t_usr_dir,
                 checkpoint_dir,
                 t2t_unk_id=None,
                 n_cpu_threads=-1,
                 max_terminal_id=-1,
                 pop_id=-1):
        """Creates a new edit T2T predictor. This constructor is
        similar to the constructor of T2TPredictor but creates a
        different computation graph which retrieves scores at each
        target position, not only the last one.
        
        Args:
            src_vocab_size (int): Source vocabulary size.
            trg_vocab_size (int): Target vocabulary size. Must be less
                than ``POS_FACTOR`` so that token IDs fit into the edit
                encoding.
            model_name (string): T2T model name.
            problem_name (string): T2T problem name.
            hparams_set_name (string): T2T hparams set name.
            trg_test_file (string): Path to a plain text file with
                initial target sentences. Can be empty. EOS is appended
                to each line.
            beam_size (int): Determines how many substitutions and
                insertions are considered at each position. The value
                stored in ``self.beam_size`` is
                ``max(1, beam_size // 10) + 1``.
            t2t_usr_dir (string): See --t2t_usr_dir in tensor2tensor.
            checkpoint_dir (string): Path to the T2T checkpoint 
                                     directory. The predictor will load
                                     the top most checkpoint in the 
                                     `checkpoints` file.
            t2t_unk_id (int): If set, use this ID to get UNK scores. If
                              None, UNK is always scored with -inf.
            n_cpu_threads (int): Number of TensorFlow CPU threads.
            max_terminal_id (int): If positive, maximum terminal ID. Needs to
                be set for syntax-based T2T models.
            pop_id (int): If positive, ID of the POP or closing bracket symbol.
                Needs to be set for syntax-based T2T models.

        Raises:
            AttributeError: If model, problem or hparams set name is
                missing, or if ``trg_vocab_size`` >= ``POS_FACTOR``.
        """
        super(EditT2TPredictor, self).__init__(t2t_usr_dir, 
                                               checkpoint_dir, 
                                               src_vocab_size,
                                               trg_vocab_size,
                                               t2t_unk_id, 
                                               n_cpu_threads,
                                               max_terminal_id,
                                               pop_id)
        if not model_name or not problem_name or not hparams_set_name:
            logging.fatal(
                "Please specify t2t_model, t2t_problem, and t2t_hparams_set!")
            raise AttributeError
        # Token IDs are stored in the lowest POS_FACTOR digits of an edit
        # token, so larger vocabularies cannot be encoded.
        if trg_vocab_size >= EditT2TPredictor.POS_FACTOR:
            logging.fatal("Target vocabulary size (%d) must be less than %d!"
                          % (trg_vocab_size, EditT2TPredictor.POS_FACTOR))
            raise AttributeError
        self.beam_size = max(1, beam_size // 10) + 1
        self.batch_size = 2048 # TODO(fstahlberg): Move to config
        self.initial_trg_sentences = None
        if trg_test_file: 
            self.initial_trg_sentences = []
            with open(trg_test_file) as f:
                for line in f:
                    self.initial_trg_sentences.append(utils.oov_to_unk(
                       [int(w) for w in line.strip().split()] + [utils.EOS_ID],
                       self.trg_vocab_size, self._t2t_unk_id))
        predictor_graph = tf.Graph()
        with predictor_graph.as_default() as g:
            hparams = trainer_lib.create_hparams(hparams_set_name)
            self._add_problem_hparams(hparams, problem_name)
            translate_model = registry.model(model_name)(
                hparams, tf.estimator.ModeKeys.EVAL)
            # One source sentence, a batch of target sentences.
            self._inputs_var = tf.placeholder(dtype=tf.int32, shape=[None],
                                              name="sgnmt_inputs")
            self._targets_var = tf.placeholder(dtype=tf.int32, shape=[None, None], 
                                               name="sgnmt_targets")
            shp = tf.shape(self._targets_var)
            bsz = shp[0]
            # Pair every target row with the same source sentence.
            inputs = tf.tile(tf.expand_dims(self._inputs_var, 0), [bsz, 1])
            features = {"inputs": expand_input_dims_for_t2t(inputs,
                                                            batched=True), 
                        "targets": expand_input_dims_for_t2t(self._targets_var,
                                                             batched=True)}
            translate_model.prepare_features_for_infer(features)
            translate_model._fill_problem_hparams_features(features)
            logits, _ = translate_model(features)
            logits = tf.squeeze(logits, [2, 3])
            # Log-probabilities at every target position (presumably shape
            # [batch, length, vocab] after the squeeze -- confirm).
            self._log_probs = log_prob_from_logits(logits)
            # Presumably selects, for batch row b, the logits at target
            # position b; predict_next feeds insertion candidates so that
            # row ``pos`` has its placeholder at position ``pos``.
            diag_logits = gather_2d(logits, tf.expand_dims(tf.range(bsz), 1))
            self._diag_log_probs = log_prob_from_logits(diag_logits)
            # Mask so that PAD positions do not contribute to the scores.
            no_pad = tf.cast(tf.not_equal(
                self._targets_var, text_encoder.PAD_ID), tf.float32)
            flat_bsz = shp[0] * shp[1]
            # Log-probability of each target token, looked up on the
            # flattened (batch * length) axis.
            word_scores = gather_2d(
                tf.reshape(self._log_probs, [flat_bsz, -1]),
                tf.reshape(self._targets_var, [flat_bsz, 1]))
            word_scores = tf.reshape(word_scores, (shp[0], shp[1])) * no_pad
            # Sentence score = sum of the (unmasked) token scores.
            self._sentence_scores = tf.reduce_sum(word_scores, -1)
            self.mon_sess = self.create_session()

    def _ins_op(self, pos, token):
        """Return a new sentence: ``self.trg_sentence`` with ``token``
        inserted before index ``pos``."""
        head, tail = self.trg_sentence[:pos], self.trg_sentence[pos:]
        return head + [token] + tail

    def _sub_op(self, pos, token):
        """Return a new sentence: ``self.trg_sentence`` with the entry at
        index ``pos`` replaced by ``token``."""
        new_sentence = list(self.trg_sentence)
        new_sentence[pos] = token
        return new_sentence

    def _del_op(self, pos):
        """Return a new sentence: ``self.trg_sentence`` without the entry at
        index ``pos``."""
        before = self.trg_sentence[:pos]
        after = self.trg_sentence[pos + 1:]
        return before + after

    def _top_n(self, scores, sort=False):
        """Column indices of the ``self.beam_size`` best scores per row.

        EOS is never selected because its cost is set to ``utils.INF``.
        Without ``sort`` the indices within a row are in arbitrary order
        (``np.argpartition``); with ``sort=True`` they are ordered from the
        best score to the worst.

        Args:
            scores (np.ndarray): 2D score array, one row per position.
            sort (bool): Whether to sort the selected indices.

        Returns:
            np.ndarray: Shape ``(scores.shape[0], self.beam_size)``.
        """
        n = self.beam_size
        costs = -scores
        costs[:, utils.EOS_ID] = utils.INF
        best = np.argpartition(costs, n, axis=1)[:, :n]
        if sort:
            rows = np.arange(best.shape[0])[:, np.newaxis]
            order = np.argsort(costs[rows, best], axis=1)
            best = best[rows, order]
        return best

    def predict_next(self):
        """Scores all single-token edits of the current target sentence.

        Candidates are substitutions at every position except the last one
        (``beam_size`` best tokens per position), insertions before every
        position (``beam_size`` best tokens each; skipped when the sentence
        is close to ``MAX_SEQ_LEN``), and deletions of every token except
        the last one. Each candidate is encoded as an SGNMT token as
        described in the class docstring.

        Returns:
            dict: Edit token -> score of the edited sentence relative to
            ``self.cur_score``. ``utils.EOS_ID`` (stop editing) always
            gets 0.0. If the sentence exceeds ``MAX_SEQ_LEN``, only EOS
            is returned.
        """
        next_sentences = {}
        logging.debug("EditT2T: Exploring score=%f sentence=%s" 
                      % (self.cur_score, " ".join(map(str, self.trg_sentence))))
        n_trg_words = len(self.trg_sentence)
        if n_trg_words > EditT2TPredictor.MAX_SEQ_LEN:
            # The class name was misspelled here (EDITT2TPredictor), which
            # raised a NameError instead of logging and returning EOS.
            logging.warning(
                "EditT2T: Target sentence exceeds maximum length (%d)"
                % EditT2TPredictor.MAX_SEQ_LEN)
            return {utils.EOS_ID: 0.0}
        # Substitutions: one forward pass over the current sentence gives
        # the best replacement tokens for every position.
        log_probs = self.mon_sess.run(self._log_probs,
            {self._inputs_var: self.src_sentence,
             self._targets_var: [self.trg_sentence]})
        top_n = self._top_n(np.squeeze(log_probs, axis=0))
        for pos, cur_token in enumerate(self.trg_sentence[:-1]):
            offset = EditT2TPredictor.SUB_OFFSET 
            offset += EditT2TPredictor.POS_FACTOR * pos
            for token in top_n[pos]:
                if token != cur_token:
                    next_sentences[offset + token] = self._sub_op(pos, token)

        # Insertions: row ``pos`` holds the sentence with a placeholder
        # (999) at position ``pos``; the diagonal log-probs presumably give
        # the token distribution at that placeholder.
        if n_trg_words < EditT2TPredictor.MAX_SEQ_LEN - 1:
            ins_trg_sentences = np.full((n_trg_words, n_trg_words+1), 999)
            for pos in xrange(n_trg_words):
                ins_trg_sentences[pos, :pos] = self.trg_sentence[:pos]
                ins_trg_sentences[pos, pos+1:] = self.trg_sentence[pos:]
            diag_log_probs = self.mon_sess.run(self._diag_log_probs,
                {self._inputs_var: self.src_sentence,
                 self._targets_var: ins_trg_sentences})
            top_n = self._top_n(np.squeeze(diag_log_probs, axis=1))
            for pos in xrange(n_trg_words):
                offset = EditT2TPredictor.INS_OFFSET 
                offset += EditT2TPredictor.POS_FACTOR * pos
                for token in top_n[pos]:
                    next_sentences[offset + token] = self._ins_op(pos, token)
        # Deletions
        idx = EditT2TPredictor.DEL_OFFSET
        for pos in xrange(n_trg_words - 1): # -1: Do not delete EOS
            next_sentences[idx] = self._del_op(pos)
            idx += EditT2TPredictor.POS_FACTOR
        # Insertions are the longest candidates (n_trg_words + 1 tokens).
        abs_scores = self._score(next_sentences, n_trg_words + 1)
        rel_scores = {i: s - self.cur_score 
                      for i, s in abs_scores.iteritems()}
        rel_scores[utils.EOS_ID] = 0.0
        return rel_scores

    def _score(self, sentences, n_trg_words=1):
        """Scores target sentences, reading from and filling ``self.cache``.

        Sentences without a cache entry are zero-padded to ``n_trg_words``
        tokens and scored in batches of at most
        ``max(1, self.batch_size // n_trg_words)`` sentences.

        Args:
            sentences (dict): Maps IDs to target sentences (lists of
                token IDs).
            n_trg_words (int): Padded length; must be at least the length
                of the longest sentence in ``sentences``.

        Returns:
            dict: Maps the IDs of ``sentences`` to sentence scores.
        """
        max_n_sens = max(1, self.batch_size // n_trg_words)
        scores = {}

        def score_batch(ids, padded_sens):
            """Score one batch and cache the results by unpadded sentence."""
            self._score_single_batch(scores, ids, padded_sens)
            # _score_single_batch caches the zero-padded arrays, but lookups
            # (here and in _update_cur_score) use the unpadded sentence, so
            # padded candidates never hit the cache without this.
            for batch_idx in ids:
                self.cache.add(sentences[batch_idx], scores[batch_idx])

        batch_ids = []
        batch_sens = []
        for idx, trg_sentence in sentences.iteritems():
            score = self.cache.get(trg_sentence)
            if score is not None:
                scores[idx] = score
                continue
            batch_ids.append(idx)
            # np.int is a deprecated alias of int (removed in NumPy 1.24).
            np_sen = np.zeros(n_trg_words, dtype=int)
            np_sen[:len(trg_sentence)] = trg_sentence
            batch_sens.append(np_sen)
            if len(batch_ids) >= max_n_sens:
                score_batch(batch_ids, batch_sens)
                batch_ids = []
                batch_sens = []
        if batch_ids:
            score_batch(batch_ids, batch_sens)
        return scores 

    def _score_single_batch(self, scores, ids, trg_sentences):
        """Run the model on one batch and record the sentence scores.

        Each score is written to ``scores`` under its ID and added to
        ``self.cache`` keyed by the sentence exactly as passed in. Nothing
        happens if ``ids`` is empty.

        Args:
            scores (dict): Output dict, updated in place.
            ids (list): Sentence IDs, parallel to ``trg_sentences``.
            trg_sentences (list): Equal-length arrays of token IDs.
        """
        if ids:
            feed = {self._inputs_var: self.src_sentence,
                    self._targets_var: np.stack(trg_sentences)}
            batch_scores = self.mon_sess.run(self._sentence_scores, feed)
            for sen_id, sen, sen_score in zip(ids, trg_sentences, batch_scores):
                self.cache.add(sen, sen_score)
                scores[sen_id] = sen_score

    def _update_cur_score(self):
        self.cur_score = self.cache.get(self.trg_sentence)
        if self.cur_score is None:
            scores = self._score({1: self.trg_sentence}, len(self.trg_sentence))
            self.cur_score = scores[1]
            self.cache.add(self.trg_sentence, self.cur_score)
    
    def initialize(self, src_sentence):
        """Reset the predictor state for a new source sentence.

        The target sentence is set to ``self.initial_trg_sentences
        [self.current_sen_id]`` if initial target sentences were given,
        and to ``[EOS]`` otherwise. The source sentence is stored with
        EOS appended and OOV ids mapped to UNK. The score cache is
        cleared and the initial target sentence is scored.

        Args:
            src_sentence (list): Source sentence as a list of token ids
                (it is concatenated with a list below).
        """
        if self.initial_trg_sentences is None:
            self.trg_sentence = [text_encoder.EOS_ID]
        else:
            self.trg_sentence = self.initial_trg_sentences[self.current_sen_id]
        self.src_sentence = utils.oov_to_unk(
            src_sentence + [text_encoder.EOS_ID], 
            self.src_vocab_size, self._t2t_unk_id)
        # Scores cached for the previous sentence are no longer valid.
        self.cache = SimpleTrie()
        self._update_cur_score()
        # Lazy argument: only formatted if debug logging is enabled.
        logging.debug("Initial score: %f", self.cur_score)
   
    def consume(self, word):
        """Apply the edit operation encoded in ``word`` to the target sentence.

        ``word`` is decoded as follows:

        - position: ``(word // POS_FACTOR) % (MAX_SEQ_LEN + 1)``
        - token: ``word % POS_FACTOR``
        - operation: ``word // 100000000`` (1: insertion, 2: substitution,
          3: deletion)

        ``utils.EOS_ID`` is ignored and leaves the state untouched. An
        unknown operation is logged and leaves the sentence unchanged,
        but the score update and cache marking below still run.

        After the edit, ``self.cur_score`` is updated and the resulting
        sentence is stored in the cache with score ``utils.NEG_INF``.

        Args:
            word (int): Edit descriptor or ``utils.EOS_ID``.
        """
        if word == utils.EOS_ID:
            return
        pos = (word // EditT2TPredictor.POS_FACTOR) \
              % (EditT2TPredictor.MAX_SEQ_LEN + 1)
        token = word % EditT2TPredictor.POS_FACTOR
        # TODO(fstahlberg): Do not hard code the following section
        op = word // 100000000
        if op == 1:  # Insertion
            self.trg_sentence = self._ins_op(pos, token)
        elif op == 2:  # Substitution
            self.trg_sentence = self._sub_op(pos, token)
        elif op == 3:  # Deletion
            self.trg_sentence = self._del_op(pos)
        else:
            logging.warning("Invalid edit descriptor %d. Ignoring...", word)
        self._update_cur_score()
        # Cached NEG_INF is returned by later ``_score`` lookups for this
        # sentence. Presumably this keeps an already visited sentence from
        # being proposed again when successors are scored.
        self.cache.add(self.trg_sentence, utils.NEG_INF)
    
    def get_state(self):
        """The predictor state is the complete target sentence."""
        return self.trg_sentence, self.cur_score
    
    def set_state(self, state):
        """The predictor state is the complete target sentence."""
        self.trg_sentence, self.cur_score = state

    def is_equal(self, state1, state2):
        """Returns true if the target sentence is the same """
        return state1[0] == state2[0]