class GreedyHeuristic(Heuristic): """This heuristic performs greedy decoding to get future cost estimates. This is expensive but can lead to very close estimates. """ def __init__(self, decoder_args, cache_estimates=True): """Creates a new ``GreedyHeuristic`` instance. The greedy heuristic performs full greedy decoding from the current state to get accurate cost estimates. However, this can be very expensive. Args: decoder_args (object): Decoder configuration passed through from the configuration API. cache_estimates (bool): Set to true to enable a cache for predictor states which have been visited during the greedy decoding. """ super(GreedyHeuristic, self).__init__() self.cache_estimates = cache_estimates self.decoder = GreedyDecoder(decoder_args) self.cache = SimpleTrie() def set_predictors(self, predictors): """Override ``Decoder.set_predictors`` to redirect the predictors to ``self.decoder`` """ self.predictors = predictors self.decoder.predictors = predictors def initialize(self, src_sentence): """Initialize the cache. """ self.cache = SimpleTrie() def estimate_future_cost(self, hypo): """Estimate the future cost by full greedy decoding. If ``self.cache_estimates`` is enabled, check cache first """ if self.cache_estimates: return self.estimate_future_cost_with_cache(hypo) else: return self.estimate_future_cost_without_cache(hypo) def estimate_future_cost_with_cache(self, hypo): """Enabled cache... """ cached_cost = self.cache.get(hypo.trgt_sentence) if not cached_cost is None: return cached_cost old_states = self.decoder.get_predictor_states() self.decoder.set_predictor_states(copy.deepcopy(old_states)) # Greedy decoding trgt_word = hypo.trgt_sentence[-1] scores = [] words = [] while trgt_word != utils.EOS_ID: self.decoder.consume(trgt_word) posterior, _ = self.decoder.apply_predictors() trgt_word = utils.argmax(posterior) scores.append(posterior[trgt_word]) words.append(trgt_word) # Update cache using scores and words for i in xrange(1, len(scores)): self.cache.add(hypo.trgt_sentence + words[:i], -sum(scores[i:])) # Reset predictor states self.decoder.set_predictor_states(old_states) return -sum(scores) def estimate_future_cost_without_cache(self, hypo): """Disabled cache... """ old_states = self.decoder.get_predictor_states() self.decoder.set_predictor_states(copy.deepcopy(old_states)) # Greedy decoding trgt_word = hypo.trgt_sentence[-1] score = 0.0 while trgt_word != utils.EOS_ID: self.decoder.consume(trgt_word) posterior, _ = self.decoder.apply_predictors() trgt_word = utils.argmax(posterior) score += posterior[trgt_word] # Reset predictor states self.decoder.set_predictor_states(old_states) return -score
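# The cache population in ``estimate_future_cost_with_cache`` stores the cost of the
# *remaining* greedy continuation for every intermediate prefix visited during the
# rollout, so later hypotheses reaching the same prefix can reuse the estimate instead
# of re-running greedy decoding. A minimal standalone sketch of that bookkeeping
# (not part of SGNMT; ``cache`` is anything with an ``add`` method such as SimpleTrie):
def _fill_greedy_cache_sketch(cache, prefix, words, scores):
    # words/scores: the greedily decoded continuation tokens and their log-probabilities
    for i in xrange(1, len(scores)):
        cache.add(prefix + words[:i], -sum(scores[i:]))
    return -sum(scores)  # future cost estimate for ``prefix`` itself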
class BagOfWordsPredictor(Predictor): """This predictor is similar to the forced predictor, but it does not enforce the word order in the reference. Therefore, it assigns 1 to all hypotheses which have the words in the reference in any order, and -inf to all other hypos. """ def __init__(self, trg_test_file, accept_subsets=False, accept_duplicates=False, heuristic_scores_file="", collect_stats_strategy='best', heuristic_add_consumed = False, heuristic_add_remaining = True, diversity_heuristic_factor = -1.0, equivalence_vocab=-1): """Creates a new bag-of-words predictor. Args: trg_test_file (string): Path to the plain text file with the target sentences. Must have the same number of lines as the number of source sentences to decode. The word order in the target sentences is not relevant for this predictor. accept_subsets (bool): If true, this predictor permits EOS even if the bag is not fully consumed yet accept_duplicates (bool): If true, counts are not updated when a word is consumed. This means that we allow a word in a bag to appear multiple times heuristic_scores_file (string): Path to the unigram scores which are used if this predictor estimates future costs collect_stats_strategy (string): best, full, or all. Defines how unigram estimates are collected for heuristic heuristic_add_consumed (bool): Set to true to add the difference between actual partial score and unigram estimates of consumed words to the predictor heuristic heuristic_add_remaining (bool): Set to true to add the sum of unigram scores of words remaining in the bag to the predictor heuristic diversity_heuristic_factor (float): Factor for diversity heuristic which penalizes hypotheses with the same bag as full hypos equivalence_vocab (int): If positive, predictor states are considered equal if the remaining words within that vocab and OOVs regarding this vocab are the same. Only relevant when using hypothesis recombination """ super(BagOfWordsPredictor, self).__init__() with open(trg_test_file) as f: self.lines = f.read().splitlines() if heuristic_scores_file: self.estimates = FileUnigramTable(heuristic_scores_file) elif collect_stats_strategy == 'best': self.estimates = BestStatsUnigramTable() elif collect_stats_strategy == 'full': self.estimates = FullStatsUnigramTable() elif collect_stats_strategy == 'all': self.estimates = AllStatsUnigramTable() else: logging.error("Unknown statistics collection strategy") self.accept_subsets = accept_subsets self.accept_duplicates = accept_duplicates self.heuristic_add_consumed = heuristic_add_consumed self.heuristic_add_remaining = heuristic_add_remaining self.equivalence_vocab = equivalence_vocab if accept_duplicates and not accept_subsets: logging.error("You enabled bow_accept_duplicates but not bow_" "accept_subsets. Therefore, the bow predictor will " "never accept end-of-sentence and could cause " "an infinite loop in the search strategy.") self.diversity_heuristic_factor = diversity_heuristic_factor self.diverse_heuristic = (diversity_heuristic_factor > 0.0) def get_unk_probability(self, posterior): """Returns negative infinity unconditionally: Words which are not in the target sentence are assigned probability 0 by this predictor. """ return NEG_INF def predict_next(self): """If the bag is empty, the only allowed symbol is EOS. Otherwise, return the list of keys in the bag. 
""" if not self.bag: # Empty bag return {utils.EOS_ID : 0.0} ret = {w : 0.0 for w in self.bag.iterkeys()} if self.accept_subsets: ret[utils.EOS_ID] = 0.0 return ret def initialize(self, src_sentence): """Creates a new bag for the current target sentence.. Args: src_sentence (list): Not used """ self.best_hypo_score = NEG_INF self.bag = {} for w in self.lines[self.current_sen_id].strip().split(): int_w = int(w) self.bag[int_w] = self.bag.get(int_w, 0) + 1 self.full_bag = dict(self.bag) def consume(self, word): """Updates the bag by deleting the consumed word. Args: word (int): Next word to consume """ if word == utils.EOS_ID: self.bag = {} return if not word in self.bag: logging.warn("Consuming word which is not in bag-of-words!") return cnt = self.bag.pop(word) if cnt > 1 and not self.accept_duplicates: self.bag[word] = cnt - 1 def get_state(self): """State of this predictor is the current bag """ return self.bag def set_state(self, state): """State of this predictor is the current bag """ self.bag = state def initialize_heuristic(self, src_sentence): """Calls ``reset`` of the used unigram table with estimates ``self.estimates`` to clear all statistics from the previous sentence Args: src_sentence (list): Not used """ self.estimates.reset() if self.diverse_heuristic: self.explored_bags = SimpleTrie() def notify(self, message, message_type = MESSAGE_TYPE_DEFAULT): """This gets called if this predictor observes the decoder. It updates unigram heuristic estimates via passing through this message to the unigram table ``self.estimates``. """ self.estimates.notify(message, message_type) if self.diverse_heuristic and message_type == MESSAGE_TYPE_FULL_HYPO: self._update_explored_bags(message) def _update_explored_bags(self, hypo): """This is called if diversity heuristic is enabled. It updates ``self.explored_bags`` """ sen = hypo.trgt_sentence for l in xrange(len(sen)): key = sen[:l] key.sort() cnt = self.explored_bags.get(key) if not cnt: cnt = 0.0 self.explored_bags.add(key, cnt + 1.0) def estimate_future_cost(self, hypo): """The bow predictor comes with its own heuristic function. We use the sum of scores of the remaining words as future cost estimator. """ acc = 0.0 if self.heuristic_add_remaining: remaining = dict(self.full_bag) remaining[utils.EOS_ID] = 1 for w in hypo.trgt_sentence: remaining[w] -= 1 acc -= sum([cnt*self.estimates.estimate(w) for w,cnt in remaining.iteritems()]) if self.diverse_heuristic: key = list(hypo.trgt_sentence) key.sort() cnt = self.explored_bags.get(key) if cnt: acc += cnt * self.diversity_heuristic_factor if self.heuristic_add_consumed: acc -= hypo.score - sum([self.estimates.estimate(w, -1000.0) for w in hypo.trgt_sentence]) return acc def _get_unk_bag(self, org_bag): if self.equivalence_vocab <= 0: return org_bag unk_bag = {} for word,cnt in org_bag.iteritems(): idx = word if word < self.equivalence_vocab else utils.UNK_ID unk_bag[idx] = unk_bag.get(idx, 0) + cnt return unk_bag def is_equal(self, state1, state2): """Returns true if the bag is the same """ return self._get_unk_bag(state1) == self._get_unk_bag(state2)
class BlocksNMTPredictor(Predictor): """This is the neural machine translation predictor. The predicted posteriors are equal to the distribution generated by the decoder network in NMT. This predictor heavily relies on the NMT example in blocks. Note that this predictor cannot be used in combination with a target side sparse feature map. See ``BlocksUnboundedNMTPredictor`` for that case. """ def __init__(self, nmt_model_path, gnmt_beta, enable_cache, config): """Creates a new NMT predictor. Args: nmt_model_path (string): Path to the NMT model file (.npz) gnmt_beta (float): If greater than 0.0, add a Google NMT style coverage penalization term (Wu et al., 2016) to the predictive scores enable_cache (bool): The NMT predictor usually has a very limited vocabulary size, and a large number of UNKs in hypotheses. This enables reusing already computed predictor states for hypotheses which differ only by NMT OOV words. config (dict): NMT configuration Raises: ValueError. If a target sparse feature map is defined """ super(BlocksNMTPredictor, self).__init__() self.gnmt_beta = gnmt_beta self.add_gnmt_coverage_term = gnmt_beta > 0.0 self.config = copy.deepcopy(config) self.enable_cache = enable_cache self.set_up_predictor(nmt_model_path) self.src_eos = self.src_sparse_feat_map.word2dense(utils.EOS_ID) def set_up_predictor(self, nmt_model_path): """Initializes the predictor with the given NMT model. Code following ``blocks.machine_translation.main``. """ self.src_vocab_size = self.config['src_vocab_size'] self.trgt_vocab_size = self.config['trg_vocab_size'] self.nmt_model = NMTModel(self.config) self.nmt_model.set_up() loader = LoadNMTUtils(nmt_model_path, self.config['saveto'], self.nmt_model.search_model) loader.load_weights() self.best_models = [] self.val_bleu_curve = [] self.src_sparse_feat_map = self.config['src_sparse_feat_map'] \ if self.config['src_sparse_feat_map'] else FlatSparseFeatMap() if self.config['trg_sparse_feat_map']: logging.fatal("Cannot use bounded vocabulary predictor with " "a target sparse feature map. Ignoring...") self.search_algorithm = MyopticSearch(samples=self.nmt_model.samples) self.search_algorithm.compile() def initialize(self, src_sentence): """Runs the encoder network to create the source annotations for the source sentence. If the cache is enabled, empty the cache. Args: src_sentence (list): List of word ids without <S> and </S> which represent the source sentence. """ self.contexts = None self.states = None self.posterior_cache = SimpleTrie() self.states_cache = SimpleTrie() self.consumed = [] seq = self.src_sparse_feat_map.words2dense( utils.oov_to_unk(src_sentence, self.src_vocab_size)) + [self.src_eos] if self.src_sparse_feat_map.dim > 1: # sparse src feats input_ = np.transpose(np.tile(seq, (1, 1, 1)), (2,0,1)) else: # word ids on the source side input_ = np.tile(seq, (1, 1)) input_values={self.nmt_model.sampling_input: input_} self.contexts, self.states, _ = self.search_algorithm.compute_initial_states_and_contexts( input_values) self.attention_records = (1 + len(src_sentence)) * [0.0] def is_history_cachable(self): """Returns true if cache is enabled and history contains UNK """ if not self.enable_cache: return False for w in self.consumed: if w == utils.UNK_ID: return True return False def predict_next(self): """Uses cache or runs the decoder network to get the distribution over the next target words. Returns: np array. Full distribution over the entire NMT vocabulary for the next target token. 
""" use_cache = self.is_history_cachable() if use_cache: posterior = self.posterior_cache.get(self.consumed) if not posterior is None: logging.debug("Loaded NMT posterior from cache for %s" % self.consumed) return self._add_gnmt_beta(posterior) # logprobs are negative log probs, i.e. greater than 0 logprobs = self.search_algorithm.compute_logprobs(self.contexts, self.states) posterior = np.multiply(logprobs[0], -1.0) if use_cache: self.posterior_cache.add(self.consumed, posterior) return self._add_gnmt_beta(posterior) def _add_gnmt_beta(self, posterior): """Adds the GNMT coverage penalization term to EOS in ``posterior`` """ if self.add_gnmt_coverage_term: posterior[utils.EOS_ID] += self.gnmt_beta * sum([np.log(max(0.0001, p)) for p in self.attention_records if p < 1.0]) return posterior def get_unk_probability(self, posterior): """Returns the UNK probability defined by NMT. """ return posterior[utils.UNK_ID] if len(posterior) > utils.UNK_ID else NEG_INF def consume(self, word): """Feeds back ``word`` to the decoder network. This includes embedding of ``word``, running the attention network and update the recurrent decoder layer. """ if word >= self.trgt_vocab_size: word = utils.UNK_ID self.consumed.append(word) use_cache = self.is_history_cachable() if use_cache: s = self.states_cache.get(self.consumed) if not s is None: logging.debug("Loaded NMT decoder states from cache for %s" % self.consumed) self.states = copy.deepcopy(s) return self.states.update(self.search_algorithm.compute_next_states( self.contexts, self.states, [word])) if use_cache: self.states_cache.add(self.consumed, copy.deepcopy(self.states)) if self.add_gnmt_coverage_term: # Keep track of attentions for pos,att in enumerate(self.states['weights'][0]): self.attention_records[pos] += att def get_state(self): """The NMT predictor state consists of the decoder network state, and (for caching) the current history of consumed words """ return self.states,self.consumed,self.attention_records def set_state(self, state): """Set the NMT predictor state. """ self.states,self.consumed,self.attention_records = state def is_equal(self, state1, state2): """Returns true if the history is the same """ _,consumed1,_ = state1 _,consumed2,_ = state2 return consumed1 == consumed2
class NgramCountPredictor(Predictor): """This predictor counts the number of n-grams in hypotheses. n-gram posteriors are loaded from a file. The predictor score is the sum of all n-gram posteriors in a hypothesis. """ def __init__(self, path, order=0, discount_factor=-1.0): """Creates a new ngram count predictor instance. Args: path (string): Path to the n-gram posteriors. File format: <ngram> : <score> (one ngram per line). Use placeholder %d for sentence id. order (int): If positive, count n-grams of the specified order. Otherwise, count all n-grams discount_factor (float): If non-negative, discount n-gram posteriors by this factor each time they are consumed """ super(NgramCountPredictor, self).__init__() self.path = path self.order = order self.discount_factor = discount_factor def get_unk_probability(self, posterior): """Always return 0.0 """ return 0.0 def predict_next(self): """Composes the posterior vector by collecting all ngrams which are consistent with the current history. """ posterior = {} for i in reversed(range(len(self.cur_history) + 1)): scores = self.ngrams.get(self.cur_history[i:]) if scores: factors = False if self.discount_factor >= 0.0: factors = self.discounts.get(self.cur_history[i:]) if not factors: for w, score in scores.iteritems(): posterior[w] = posterior.get(w, 0.0) + score else: for w, score in scores.iteritems(): posterior[w] = posterior.get(w, 0.0) + \ factors.get(w, 1.0) * score return posterior def _load_posteriors(self, path): """Sets up self.max_history_len and self.ngrams """ self.max_history_len = 0 self.ngrams = SimpleTrie() with open(path) as f: for line in f: ngram, score = line.split(':') words = [int(w) for w in ngram.strip().split()] if self.order > 0 and len(words) != self.order: continue hist = words[:-1] last_word = words[-1] if last_word == utils.GO_ID: continue self.max_history_len = max(self.max_history_len, len(hist)) p = self.ngrams.get(hist) if p: p[last_word] = float(score.strip()) else: self.ngrams.add(hist, {last_word: float(score.strip())}) def initialize(self, src_sentence): """Loads n-gram posteriors and resets history. Args: src_sentence (list): not used """ self._load_posteriors( utils.get_path(self.path, self.current_sen_id + 1)) self.cur_history = [utils.GO_ID] self.discounts = SimpleTrie() def consume(self, word): """Adds ``word`` to the current history. Shorten if the extended history exceeds ``max_history_len``. Args: word (int): Word to add to the history. """ self.cur_history.append(word) if len(self.cur_history) > self.max_history_len: self.cur_history = self.cur_history[-self.max_history_len:] if self.discount_factor >= 0.0: for i in range(len(self.cur_history)): key = self.cur_history[i:-1] factors = self.discounts.get(key) if not factors: factors = {word: self.discount_factor} else: factors[word] = factors.get(word, 1.0) * self.discount_factor self.discounts.add(key, factors) def get_state(self): """Current history is the predictor state """ return self.cur_history, self.discounts def set_state(self, state): """Current history is the predictor state """ self.cur_history, self.discounts = state def reset(self): """Empty method. """ pass def is_equal(self, state1, state2): """Hypothesis recombination is not supported if discounting is enabled. 
""" if self.discount_factor >= 0.0: return False hist1 = state1[0] hist2 = state2[0] if hist1 == hist2: # Return true if histories match return True if len(hist1) > len(hist2): hist_long = hist1 hist_short = hist2 else: hist_long = hist2 hist_short = hist1 min_len = len(hist_short) for n in xrange(1, min_len + 1): # Look up non matching in self.ngrams key1 = hist1[-n:] key2 = hist2[-n:] if key1 != key2: if self.ngrams.get(key1) or self.ngrams.get(key2): return False for n in xrange(min_len + 1, len(hist_long) + 1): if self.ngrams.get(hist_long[-n:]): return False return True
class EditT2TPredictor(_BaseTensor2TensorPredictor): """This predictor can be used for T2T models conditioning on the full target sentence. The predictor state is a full target sentence. The state can be changed by insertions, substitutions, and deletions of single tokens, whereas each operation is encoded as SGNMT token in the following way: 1xxxyyyyy: Insert the token yyyyy at position xxx. 2xxxyyyyy: Replace the xxx-th word with the token yyyyy. 3xxx00000: Delete the xxx-th token. """ INS_OFFSET = 100000000 SUB_OFFSET = 200000000 DEL_OFFSET = 300000000 POS_FACTOR = 100000 MAX_SEQ_LEN = 999 def __init__(self, src_vocab_size, trg_vocab_size, model_name, problem_name, hparams_set_name, trg_test_file, beam_size, t2t_usr_dir, checkpoint_dir, t2t_unk_id=None, n_cpu_threads=-1, max_terminal_id=-1, pop_id=-1): """Creates a new edit T2T predictor. This constructor is similar to the constructor of T2TPredictor but creates a different computation graph which retrieves scores at each target position, not only the last one. Args: src_vocab_size (int): Source vocabulary size. trg_vocab_size (int): Target vocabulary size. model_name (string): T2T model name. problem_name (string): T2T problem name. hparams_set_name (string): T2T hparams set name. trg_test_file (string): Path to a plain text file with initial target sentences. Can be empty. beam_size (int): Determines how many substitutions and insertions are considered at each position. t2t_usr_dir (string): See --t2t_usr_dir in tensor2tensor. checkpoint_dir (string): Path to the T2T checkpoint directory. The predictor will load the top most checkpoint in the `checkpoints` file. t2t_unk_id (int): If set, use this ID to get UNK scores. If None, UNK is always scored with -inf. n_cpu_threads (int): Number of TensorFlow CPU threads. max_terminal_id (int): If positive, maximum terminal ID. Needs to be set for syntax-based T2T models. pop_id (int): If positive, ID of the POP or closing bracket symbol. Needs to be set for syntax-based T2T models. """ super(EditT2TPredictor, self).__init__(t2t_usr_dir, checkpoint_dir, src_vocab_size, trg_vocab_size, t2t_unk_id, n_cpu_threads, max_terminal_id, pop_id) if not model_name or not problem_name or not hparams_set_name: logging.fatal( "Please specify t2t_model, t2t_problem, and t2t_hparams_set!") raise AttributeError if trg_vocab_size >= EditT2TPredictor.POS_FACTOR: logging.fatal("Target vocabulary size (%d) must be less than %d!" 
% (trg_vocab_size, EditT2TPredictor.POS_FACTOR)) raise AttributeError self.beam_size = max(1, beam_size // 10) + 1 self.batch_size = 2048 # TODO(fstahlberg): Move to config self.initial_trg_sentences = None if trg_test_file: self.initial_trg_sentences = [] with open(trg_test_file) as f: for line in f: self.initial_trg_sentences.append(utils.oov_to_unk( [int(w) for w in line.strip().split()] + [utils.EOS_ID], self.trg_vocab_size, self._t2t_unk_id)) predictor_graph = tf.Graph() with predictor_graph.as_default() as g: hparams = trainer_lib.create_hparams(hparams_set_name) self._add_problem_hparams(hparams, problem_name) translate_model = registry.model(model_name)( hparams, tf.estimator.ModeKeys.EVAL) self._inputs_var = tf.placeholder(dtype=tf.int32, shape=[None], name="sgnmt_inputs") self._targets_var = tf.placeholder(dtype=tf.int32, shape=[None, None], name="sgnmt_targets") shp = tf.shape(self._targets_var) bsz = shp[0] inputs = tf.tile(tf.expand_dims(self._inputs_var, 0), [bsz, 1]) features = {"inputs": expand_input_dims_for_t2t(inputs, batched=True), "targets": expand_input_dims_for_t2t(self._targets_var, batched=True)} translate_model.prepare_features_for_infer(features) translate_model._fill_problem_hparams_features(features) logits, _ = translate_model(features) logits = tf.squeeze(logits, [2, 3]) self._log_probs = log_prob_from_logits(logits) diag_logits = gather_2d(logits, tf.expand_dims(tf.range(bsz), 1)) self._diag_log_probs = log_prob_from_logits(diag_logits) no_pad = tf.cast(tf.not_equal( self._targets_var, text_encoder.PAD_ID), tf.float32) flat_bsz = shp[0] * shp[1] word_scores = gather_2d( tf.reshape(self._log_probs, [flat_bsz, -1]), tf.reshape(self._targets_var, [flat_bsz, 1])) word_scores = tf.reshape(word_scores, (shp[0], shp[1])) * no_pad self._sentence_scores = tf.reduce_sum(word_scores, -1) self.mon_sess = self.create_session() def _ins_op(self, pos, token): """Returns a copy of trg sentence after an insertion.""" return self.trg_sentence[:pos] + [token] + self.trg_sentence[pos:] def _sub_op(self, pos, token): """Returns a copy of trg sentence after a substitution.""" ret = list(self.trg_sentence) ret[pos] = token return ret def _del_op(self, pos): """Returns a copy of trg sentence after a deletion.""" return self.trg_sentence[:pos] + self.trg_sentence[pos+1:] def _top_n(self, scores, sort=False): """Sorted indices of beam_size best entries along axis 1""" costs = -scores costs[:, utils.EOS_ID] = utils.INF top_n_indices = np.argpartition( costs, self.beam_size, axis=1)[:, :self.beam_size] if not sort: return top_n_indices b_indices = np.expand_dims(np.arange(top_n_indices.shape[0]), axis=1) sorted_indices = np.argsort(costs[b_indices, top_n_indices], axis=1) return top_n_indices[b_indices, sorted_indices] def predict_next(self): """Call the T2T model in self.mon_sess.""" next_sentences = {} logging.debug("EditT2T: Exploring score=%f sentence=%s" % (self.cur_score, " ".join(map(str, self.trg_sentence)))) n_trg_words = len(self.trg_sentence) if n_trg_words > EditT2TPredictor.MAX_SEQ_LEN: logging.warn("EditT2T: Target sentence exceeds maximum length (%d)" % EditT2TPredictor.MAX_SEQ_LEN) return {utils.EOS_ID: 0.0} # Substitutions log_probs = self.mon_sess.run(self._log_probs, {self._inputs_var: self.src_sentence, self._targets_var: [self.trg_sentence]}) top_n = self._top_n(np.squeeze(log_probs, axis=0)) for pos, cur_token in enumerate(self.trg_sentence[:-1]): offset = EditT2TPredictor.SUB_OFFSET offset += EditT2TPredictor.POS_FACTOR * pos for token in top_n[pos]: if token != 
cur_token: next_sentences[offset + token] = self._sub_op(pos, token) # Insertions if n_trg_words < EditT2TPredictor.MAX_SEQ_LEN - 1: ins_trg_sentences = np.full((n_trg_words, n_trg_words+1), 999) for pos in range(n_trg_words): ins_trg_sentences[pos, :pos] = self.trg_sentence[:pos] ins_trg_sentences[pos, pos+1:] = self.trg_sentence[pos:] diag_log_probs = self.mon_sess.run(self._diag_log_probs, {self._inputs_var: self.src_sentence, self._targets_var: ins_trg_sentences}) top_n = self._top_n(np.squeeze(diag_log_probs, axis=1)) for pos in range(n_trg_words): offset = EditT2TPredictor.INS_OFFSET offset += EditT2TPredictor.POS_FACTOR * pos for token in top_n[pos]: next_sentences[offset + token] = self._ins_op(pos, token) # Deletions idx = EditT2TPredictor.DEL_OFFSET for pos in range(n_trg_words - 1): # -1: Do not delete EOS next_sentences[idx] = self._del_op(pos) idx += EditT2TPredictor.POS_FACTOR abs_scores = self._score(next_sentences, n_trg_words + 1) rel_scores = {i: s - self.cur_score for i, s in abs_scores.items()} rel_scores[utils.EOS_ID] = 0.0 return rel_scores def _score(self, sentences, n_trg_words=1): max_n_sens = max(1, self.batch_size // n_trg_words) scores = {} batch_ids = [] batch_sens = [] for idx, trg_sentence in sentences.items(): score = self.cache.get(trg_sentence) if score is None: batch_ids.append(idx) np_sen = np.zeros(n_trg_words, dtype=np.int) np_sen[:len(trg_sentence)] = trg_sentence batch_sens.append(np_sen) if len(batch_ids) >= max_n_sens: self._score_single_batch(scores, batch_ids, batch_sens) batch_ids = [] batch_sens = [] else: scores[idx] = score self._score_single_batch(scores, batch_ids, batch_sens) return scores def _score_single_batch(self, scores, ids, trg_sentences): """Score sentences and add them to scores and the cache.""" if not ids: return batch_scores = self.mon_sess.run(self._sentence_scores, {self._inputs_var: self.src_sentence, self._targets_var: np.stack(trg_sentences)}) for idx, sen, score in zip(ids, trg_sentences, batch_scores): self.cache.add(sen, score) scores[idx] = score def _update_cur_score(self): self.cur_score = self.cache.get(self.trg_sentence) if self.cur_score is None: scores = self._score({1: self.trg_sentence}, len(self.trg_sentence)) self.cur_score = scores[1] self.cache.add(self.trg_sentence, self.cur_score) def initialize(self, src_sentence): """Set src_sentence, reset consumed.""" if self.initial_trg_sentences is None: self.trg_sentence = [text_encoder.EOS_ID] else: self.trg_sentence = self.initial_trg_sentences[self.current_sen_id] self.src_sentence = utils.oov_to_unk( src_sentence + [text_encoder.EOS_ID], self.src_vocab_size, self._t2t_unk_id) self.cache = SimpleTrie() self._update_cur_score() logging.debug("Initial score: %f" % self.cur_score) def consume(self, word): """Applies the edit operation encoded by ``word`` to the current target sentence.""" if word == utils.EOS_ID: return pos = (word // EditT2TPredictor.POS_FACTOR) \ % (EditT2TPredictor.MAX_SEQ_LEN + 1) token = word % EditT2TPredictor.POS_FACTOR # TODO(fstahlberg): Do not hard code the following section op = word // 100000000 if op == 1: # Insertion self.trg_sentence = self._ins_op(pos, token) elif op == 2: # Substitution self.trg_sentence = self._sub_op(pos, token) elif op == 3: # Deletion self.trg_sentence = self._del_op(pos) else: logging.warn("Invalid edit descriptor %d. Ignoring..." 
% word) self._update_cur_score() self.cache.add(self.trg_sentence, utils.NEG_INF) def get_state(self): """The predictor state is the complete target sentence.""" return self.trg_sentence, self.cur_score def set_state(self, state): """The predictor state is the complete target sentence.""" self.trg_sentence, self.cur_score = state def is_equal(self, state1, state2): """Returns true if the target sentence is the same """ return state1[0] == state2[0]
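# Worked example (hypothetical helper, not part of the predictor) of the OxxxYYYYY
# edit-operation encoding documented in the class docstring above: the first digit
# selects the operation, the next three digits the position, and the last five
# digits the token ID. This mirrors what EditT2TPredictor.consume() does.
def _decode_edit_token_example(word,
                               pos_factor=EditT2TPredictor.POS_FACTOR,
                               op_factor=EditT2TPredictor.INS_OFFSET):
    op = word // op_factor                        # 1=insert, 2=substitute, 3=delete
    pos = (word // pos_factor) % (op_factor // pos_factor)
    token = word % pos_factor
    return op, pos, token
# _decode_edit_token_example(203000042) == (2, 30, 42), i.e. "replace the token at
# position 30 with token ID 42".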
class TensorFlowNMTPredictor(Predictor): '''Neural MT predictor''' def __init__(self, enable_cache, config, session): super(TensorFlowNMTPredictor, self).__init__() self.config = config self.session = session # Add missing entries in config if self.config['encoder'] == "bow": self.config['init_backward'] = False self.config['use_seqlen'] = False else: self.config['bow_init_const'] = False self.config['use_bow_mask'] = False # Load tensorflow model self.model, self.training_graph, self.encoding_graph, \ self.single_step_decoding_graph, self.buckets = tf_model_utils.load_model(self.session, config) self.model.batch_size = 1 # We decode one sentence at a time. self.enc_out = {} self.decoder_input = [tf_data_utils.GO_ID] self.dec_state = {} self.bucket_id = -1 self.num_heads = 1 self.word_count = 0 if config['no_pad_symbol']: # This needs to be set in tensorflow data_utils for correct source masks tf_data_utils.no_pad_symbol() logging.info("UNK_ID=%d" % tf_data_utils.UNK_ID) logging.info("PAD_ID=%d" % tf_data_utils.PAD_ID) self.enable_cache = enable_cache if self.enable_cache: logging.info("Cache enabled..") def initialize(self, src_sentence): # src_sentence is list of integers, without <s> and </s> self.enc_out = {} self.decoder_input = [tf_data_utils.GO_ID] self.dec_state = {} self.word_count = 0 self.consumed = [] self.posterior_cache = SimpleTrie() self.states_cache = SimpleTrie() src_sentence = [w if w < self.config['src_vocab_size'] else tf_data_utils.UNK_ID for w in src_sentence] if self.config['add_src_eos']: src_sentence.append(tf_data_utils.EOS_ID) feasible_buckets = [b for b in xrange(len(self.buckets)) if self.buckets[b][0] >= len(src_sentence)] if not feasible_buckets: # Get a new bucket bucket = tf_model_utils.make_bucket(len(src_sentence)) logging.info("Add new bucket={} and update model".format(bucket)) self.buckets.append(bucket) self.model.update_buckets(self.buckets) self.bucket_id = len(self.buckets) - 1 else: self.bucket_id = min(feasible_buckets) encoder_inputs, _, _, sequence_length, src_mask, bow_mask = self.training_graph.get_batch( {self.bucket_id: [(src_sentence, [])]}, self.bucket_id, self.config['encoder']) logging.info("bucket={}".format(self.buckets[self.bucket_id])) last_enc_state, self.enc_out = self.encoding_graph.encode( self.session, encoder_inputs, self.bucket_id, sequence_length) # Initialize decoder state with last encoder state self.dec_state["dec_state"] = last_enc_state for a in xrange(self.num_heads): self.dec_state["dec_attns_%d" % a] = np.zeros((1, self.enc_out['enc_v_0'].size), dtype=np.float32) if self.config['use_src_mask']: self.dec_state["src_mask"] = src_mask self.src_mask_orig = src_mask.copy() if self.config['use_bow_mask']: self.dec_state["bow_mask"] = bow_mask self.bow_mask_orig = bow_mask.copy() def is_history_cachable(self): """Returns true if cache is enabled and history contains UNK """ if not self.enable_cache: return False for w in self.consumed: if w == tf_data_utils.UNK_ID: return True return False def predict_next(self): # should return list, numpy array, or dictionary if self.decoder_input[0] == tf_data_utils.EOS_ID: # Predict EOS return {tf_data_utils.EOS_ID: 0} use_cache = self.is_history_cachable() if use_cache: posterior = self.posterior_cache.get(self.consumed) if not posterior is None: logging.debug("Loaded NMT posterior from cache for %s" % self.consumed) return posterior output, self.dec_state = self.single_step_decoding_graph.decode(self.session, self.enc_out, self.dec_state, self.decoder_input, self.bucket_id, 
self.config['use_src_mask'], self.word_count, self.config['use_bow_mask']) if use_cache: self.posterior_cache.add(self.consumed, output[0]) return output[0] def get_unk_probability(self, posterior): # posterior is the returned value of the last predict_next call return posterior[utils.UNK_ID] if len(posterior) > 1 else float("-inf") def consume(self, word): if word >= self.config['trg_vocab_size']: word = tf_data_utils.UNK_ID # history is kept according to nmt vocab self.consumed.append(word) use_cache = self.is_history_cachable() if use_cache: s = self.states_cache.get(self.consumed) if not s is None: logging.debug("Loaded NMT decoder states from cache for %s" % self.consumed) states = copy.deepcopy(s) self.decoder_input = states[0] self.dec_state = states[1] self.word_count = states[2] return self.decoder_input = [word] self.word_count = self.word_count + 1 if use_cache: states = (self.decoder_input, self.dec_state, self.word_count) self.states_cache.add(self.consumed, copy.deepcopy(states)) def get_state(self): return (self.decoder_input, self.dec_state, self.word_count, self.consumed) def set_state(self, state): self.decoder_input, self.dec_state, self.word_count, self.consumed = state def is_equal(self, state1, state2): """Returns true if the history is the same """ _, _, _, consumed1 = state1 _, _, _, consumed2 = state2 return consumed1 == consumed2
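# Illustrative sketch (toy bucket sizes) of the bucket selection performed in
# initialize() above: the smallest bucket whose source length fits is chosen, and a
# fresh bucket is created via tf_model_utils.make_bucket() when none fits.
def _pick_bucket_example(buckets=((10, 15), (20, 25), (40, 50)), src_len=23):
    feasible = [b for b in xrange(len(buckets)) if buckets[b][0] >= src_len]
    return min(feasible) if feasible else None    # -> 2, i.e. the (40, 50) bucket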
class Word2charPredictor(UnboundedVocabularyPredictor): """This predictor wraps word level predictors when SGNMT is running on the character level. The mapping between word ID and character ID sequence is loaded from the file system. All characters which do not appear in that mapping are treated as word boundary markers. The wrapper blocks consume and predict_next calls until a word boundary marker is consumed, and updates the slave predictor according to the word between the last two word boundaries. The mapping is done only on the target side, and the source sentences are passed through as they are. To use alternative tokenization on the source side, see the altsrc predictor wrapper. The word2char wrapper is always an ``UnboundedVocabularyPredictor``. """ def __init__(self, map_path, slave_predictor): """Creates a new word2char wrapper predictor. The map_path file has to be a plain text file, each line containing the mapping from a word index to the character index sequence (format: word char1 char2... charn). Args: map_path (string): Path to the mapping file slave_predictor (Predictor): Instance of the predictor with a different wmap than SGNMT """ super(Word2charPredictor, self).__init__() self.slave_predictor = slave_predictor self.words = SimpleTrie() self.word_chars = {} with open(map_path) as f: for line in f: l = [int(x) for x in line.strip().split()] word = l[0] chars = l[1:] self.words.add(chars, word) for c in chars: self.word_chars[c] = True if isinstance(slave_predictor, UnboundedVocabularyPredictor): self._get_stub_prob = self._get_stub_prob_unbounded self._start_new_word = self._start_new_word_unbounded else: self._get_stub_prob = self._get_stub_prob_bounded self._start_new_word = self._start_new_word_bounded def initialize(self, src_sentence): """Pass through to slave predictor. The source sentence is not modified """ self.slave_predictor.initialize(src_sentence) self._start_new_word() def initialize_heuristic(self, src_sentence): """Pass through to slave predictor. The source sentence is not modified """ self.slave_predictor.initialize_heuristic(src_sentence) def _update_slave_vars(self, posterior): self.slave_unk = self.slave_predictor.get_unk_probability(posterior) self.slave_go = common_get(posterior, utils.GO_ID, self.slave_unk) self.slave_eos = common_get(posterior, utils.EOS_ID, self.slave_unk) def _start_new_word_unbounded(self): """start_new_word implementation for unbounded vocabulary slave predictors. Needs to set slave_go, slave_eos, and slave_unk """ self.word_stub = [] posterior = self.slave_predictor.predict_next([utils.UNK_ID, utils.GO_ID, utils.EOS_ID]) self._update_slave_vars(posterior) def _start_new_word_bounded(self): """start_new_word implementation for bounded vocabulary slave predictors. Needs to set slave_go, slave_eos, slave_unk, and slave_posterior """ self.word_stub = [] self.slave_posterior = self.slave_predictor.predict_next() self._update_slave_vars(self.slave_posterior) def _get_stub_prob_unbounded(self): """get_stub_prob implementation for unbounded vocabulary slave predictors. """ word = self.words.get(self.word_stub) if word: posterior = self.slave_predictor.predict_next([word]) return common_get(posterior, word, self.slave_unk) return self.slave_unk def _get_stub_prob_bounded(self): """get_stub_prob implementation for bounded vocabulary slave predictors. 
""" word = self.words.get(self.word_stub) return common_get(self.slave_posterior, word if word else utils.UNK_ID, self.slave_unk) def predict_next(self, trgt_words): posterior = {} stub_prob = False for ch in trgt_words: if ch in self.word_chars: posterior[ch] = 0.0 else: # Word boundary marker if stub_prob is False: stub_prob = self._get_stub_prob() if self.word_stub else 0.0 posterior[ch] = stub_prob if utils.GO_ID in posterior: posterior[utils.GO_ID] += self.slave_go if utils.EOS_ID in posterior: posterior[utils.EOS_ID] += self.slave_eos return posterior def get_unk_probability(self, posterior): """This is about the unkown character, not word. Since the word level slave predictor has no notion of the unknown character, we return NEG_INF unconditionally. """ return NEG_INF def consume(self, word): """If ``word`` is a word boundary marker, truncate ``word_stub`` and let the slave predictor consume word_stub. Otherwise, extend ``word_stub`` by the character. """ if word in self.word_chars: self.word_stub.append(word) elif self.word_stub: word = self.words.get(self.word_stub) self.slave_predictor.consume(word if word else utils.UNK_ID) self._start_new_word() def get_state(self): """Pass through to slave predictor """ return self.word_stub, self.slave_predictor.get_state() def set_state(self, state): """Pass through to slave predictor """ self.word_stub, slave_state = state self.slave_predictor.set_state(slave_state) def reset(self): """Pass through to slave predictor """ self.slave_predictor.reset() def estimate_future_cost(self, hypo): """Not supported """ logging.warn("Cannot use future cost estimates of predictors " "wrapped by word2char") return 0.0 def set_current_sen_id(self, cur_sen_id): """We need to override this method to propagate current\_ sentence_id to the slave predictor """ super(Word2charPredictor, self).set_current_sen_id(cur_sen_id) self.slave_predictor.set_current_sen_id(cur_sen_id) def is_equal(self, state1, state2): """Pass through to slave predictor """ stub1, slave_state1 = state1 stub2, slave_state2 = state2 return (stub1 == stub2 and self.slave_predictor.is_equal(slave_state1, slave_state2))
class TensorFlowNMTPredictor(Predictor): '''Neural MT predictor''' def __init__(self, enable_cache, config, session): super(TensorFlowNMTPredictor, self).__init__() self.config = config self.session = session # Add missing entries in config if self.config['encoder'] == "bow": self.config['init_backward'] = False self.config['use_seqlen'] = False else: self.config['bow_init_const'] = False self.config['use_bow_mask'] = False # Load tensorflow model self.model, self.training_graph, self.encoding_graph, \ self.single_step_decoding_graph, self.buckets = tf_model_utils.load_model(self.session, config) self.model.batch_size = 1 # We decode one sentence at a time. self.enc_out = {} self.decoder_input = [tf_data_utils.GO_ID] self.dec_state = {} self.bucket_id = -1 self.num_heads = 1 self.word_count = 0 if config['no_pad_symbol']: # This needs to be set in tensorflow data_utils for correct source masks tf_data_utils.no_pad_symbol() logging.info("UNK_ID=%d" % tf_data_utils.UNK_ID) logging.info("PAD_ID=%d" % tf_data_utils.PAD_ID) self.enable_cache = enable_cache if self.enable_cache: logging.info("Cache enabled..") def initialize(self, src_sentence): # src_sentence is list of integers, without <s> and </s> self.reset() self.posterior_cache = SimpleTrie() self.states_cache = SimpleTrie() src_sentence = [ w if w < self.config['src_vocab_size'] else tf_data_utils.UNK_ID for w in src_sentence ] if self.config['add_src_eos']: src_sentence.append(tf_data_utils.EOS_ID) feasible_buckets = [ b for b in xrange(len(self.buckets)) if self.buckets[b][0] >= len(src_sentence) ] if not feasible_buckets: # Get a new bucket bucket = tf_model_utils.make_bucket(len(src_sentence)) logging.info("Add new bucket={} and update model".format(bucket)) self.buckets.append(bucket) self.model.update_buckets(self.buckets) self.bucket_id = len(self.buckets) - 1 else: self.bucket_id = min(feasible_buckets) encoder_inputs, _, _, sequence_length, src_mask, bow_mask = self.training_graph.get_batch( {self.bucket_id: [(src_sentence, [])]}, self.bucket_id, self.config['encoder']) logging.info("bucket={}".format(self.buckets[self.bucket_id])) last_enc_state, self.enc_out = self.encoding_graph.encode( self.session, encoder_inputs, self.bucket_id, sequence_length) # Initialize decoder state with last encoder state self.dec_state["dec_state"] = last_enc_state for a in xrange(self.num_heads): self.dec_state["dec_attns_%d" % a] = np.zeros( (1, self.enc_out['enc_v_0'].size), dtype=np.float32) if self.config['use_src_mask']: self.dec_state["src_mask"] = src_mask self.src_mask_orig = src_mask.copy() if self.config['use_bow_mask']: self.dec_state["bow_mask"] = bow_mask self.bow_mask_orig = bow_mask.copy() def is_history_cachable(self): """Returns true if cache is enabled and history contains UNK """ if not self.enable_cache: return False for w in self.consumed: if w == tf_data_utils.UNK_ID: return True return False def predict_next(self): # should return list, numpy array, or dictionary if self.decoder_input[0] == tf_data_utils.EOS_ID: # Predict EOS return {tf_data_utils.EOS_ID: 0} use_cache = self.is_history_cachable() if use_cache: posterior = self.posterior_cache.get(self.consumed) if not posterior is None: logging.debug("Loaded NMT posterior from cache for %s" % self.consumed) return posterior output, self.dec_state = self.single_step_decoding_graph.decode( self.session, self.enc_out, self.dec_state, self.decoder_input, self.bucket_id, self.config['use_src_mask'], self.word_count, self.config['use_bow_mask']) if use_cache: 
self.posterior_cache.add(self.consumed, output[0]) return output[0] def get_unk_probability(self, posterior): # posterior is the returned value of the last predict_next call return posterior[utils.UNK_ID] if len(posterior) > 1 else float("-inf") def consume(self, word): if word >= self.config['trg_vocab_size']: word = tf_data_utils.UNK_ID # history is kept according to nmt vocab logging.debug("Consume word={}".format(word)) self.consumed.append(word) use_cache = self.is_history_cachable() if use_cache: s = self.states_cache.get(self.consumed) if not s is None: logging.debug("Loaded NMT decoder states from cache for %s" % self.consumed) states = copy.deepcopy(s) self.decoder_input = states[0] self.dec_state = states[1] self.word_count = states[2] return self.decoder_input = [word] self.word_count = self.word_count + 1 if use_cache: states = (self.decoder_input, self.dec_state, self.word_count) self.states_cache.add(self.consumed, copy.deepcopy(states)) def get_state(self): return (self.decoder_input, self.dec_state, self.word_count, self.consumed) def set_state(self, state): self.decoder_input, self.dec_state, self.word_count, self.consumed = state def reset(self): self.enc_out = {} self.decoder_input = [tf_data_utils.GO_ID] self.dec_state = {} self.word_count = 0 self.consumed = [] def is_equal(self, state1, state2): """Returns true if the history is the same """ _, _, _, consumed1 = state1 _, _, _, consumed2 = state2 return consumed1 == consumed2
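Both caches above are keyed by the consumed history and only consulted when that history contains UNK, because different OOV surface forms collapse onto the same UNK sequence and therefore yield identical model outputs. A minimal sketch of this pattern follows, with a plain dict standing in for SimpleTrie and a hypothetical ``decode_step`` callable standing in for the expensive single-step decoding call.

def cached_posterior(consumed, posterior_cache, decode_step):
    """Return the posterior for the consumed history, recomputing it
    only on a cache miss."""
    key = tuple(consumed)
    posterior = posterior_cache.get(key)
    if posterior is None:
        posterior = decode_step(consumed)  # expensive TF session call
        posterior_cache[key] = posterior
    return posterior

# Two hypotheses whose histories differ only in the OOV surface form map
# to the same UNK-containing key, so the second lookup is a cache hit.
cache = {}
dummy_step = lambda hist: {0: -0.1, 3: -2.3}  # placeholder for the decoder
cached_posterior([5, 1, 7], cache, dummy_step)  # computed (1 = UNK id here)
cached_posterior([5, 1, 7], cache, dummy_step)  # served from cache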
class EditT2TPredictor(_BaseTensor2TensorPredictor): """This predictor can be used for T2T models conditioning on the full target sentence. The predictor state is a full target sentence. The state can be changed by insertions, substitutions, and deletions of single tokens, where each operation is encoded as an SGNMT token in the following way: 1xxxyyyyy: Insert the token yyyyy at position xxx. 2xxxyyyyy: Replace the xxx-th word with the token yyyyy. 3xxx00000: Delete the xxx-th token. """ INS_OFFSET = 100000000 SUB_OFFSET = 200000000 DEL_OFFSET = 300000000 POS_FACTOR = 100000 MAX_SEQ_LEN = 999 def __init__(self, src_vocab_size, trg_vocab_size, model_name, problem_name, hparams_set_name, trg_test_file, beam_size, t2t_usr_dir, checkpoint_dir, t2t_unk_id=None, n_cpu_threads=-1, max_terminal_id=-1, pop_id=-1): """Creates a new edit T2T predictor. This constructor is similar to the constructor of T2TPredictor but creates a different computation graph which retrieves scores at each target position, not only the last one. Args: src_vocab_size (int): Source vocabulary size. trg_vocab_size (int): Target vocabulary size. model_name (string): T2T model name. problem_name (string): T2T problem name. hparams_set_name (string): T2T hparams set name. trg_test_file (string): Path to a plain text file with initial target sentences. Can be empty. beam_size (int): Determines how many substitutions and insertions are considered at each position. t2t_usr_dir (string): See --t2t_usr_dir in tensor2tensor. checkpoint_dir (string): Path to the T2T checkpoint directory. The predictor will load the topmost checkpoint in the `checkpoints` file. t2t_unk_id (int): If set, use this ID to get UNK scores. If None, UNK is always scored with -inf. n_cpu_threads (int): Number of TensorFlow CPU threads. max_terminal_id (int): If positive, maximum terminal ID. Needs to be set for syntax-based T2T models. pop_id (int): If positive, ID of the POP or closing bracket symbol. Needs to be set for syntax-based T2T models. """ super(EditT2TPredictor, self).__init__(t2t_usr_dir, checkpoint_dir, src_vocab_size, trg_vocab_size, t2t_unk_id, n_cpu_threads, max_terminal_id, pop_id) if not model_name or not problem_name or not hparams_set_name: logging.fatal( "Please specify t2t_model, t2t_problem, and t2t_hparams_set!") raise AttributeError if trg_vocab_size >= EditT2TPredictor.POS_FACTOR: logging.fatal("Target vocabulary size (%d) must be less than %d!"
% (trg_vocab_size, EditT2TPredictor.POS_FACTOR)) raise AttributeError self.beam_size = max(1, beam_size // 10) + 1 self.batch_size = 2048 # TODO(fstahlberg): Move to config self.initial_trg_sentences = None if trg_test_file: self.initial_trg_sentences = [] with open(trg_test_file) as f: for line in f: self.initial_trg_sentences.append(utils.oov_to_unk( [int(w) for w in line.strip().split()] + [utils.EOS_ID], self.trg_vocab_size, self._t2t_unk_id)) predictor_graph = tf.Graph() with predictor_graph.as_default() as g: hparams = trainer_lib.create_hparams(hparams_set_name) self._add_problem_hparams(hparams, problem_name) translate_model = registry.model(model_name)( hparams, tf.estimator.ModeKeys.EVAL) self._inputs_var = tf.placeholder(dtype=tf.int32, shape=[None], name="sgnmt_inputs") self._targets_var = tf.placeholder(dtype=tf.int32, shape=[None, None], name="sgnmt_targets") shp = tf.shape(self._targets_var) bsz = shp[0] inputs = tf.tile(tf.expand_dims(self._inputs_var, 0), [bsz, 1]) features = {"inputs": expand_input_dims_for_t2t(inputs, batched=True), "targets": expand_input_dims_for_t2t(self._targets_var, batched=True)} translate_model.prepare_features_for_infer(features) translate_model._fill_problem_hparams_features(features) logits, _ = translate_model(features) logits = tf.squeeze(logits, [2, 3]) self._log_probs = log_prob_from_logits(logits) diag_logits = gather_2d(logits, tf.expand_dims(tf.range(bsz), 1)) self._diag_log_probs = log_prob_from_logits(diag_logits) no_pad = tf.cast(tf.not_equal( self._targets_var, text_encoder.PAD_ID), tf.float32) flat_bsz = shp[0] * shp[1] word_scores = gather_2d( tf.reshape(self._log_probs, [flat_bsz, -1]), tf.reshape(self._targets_var, [flat_bsz, 1])) word_scores = tf.reshape(word_scores, (shp[0], shp[1])) * no_pad self._sentence_scores = tf.reduce_sum(word_scores, -1) self.mon_sess = self.create_session() def _ins_op(self, pos, token): """Returns a copy of trg sentence after an insertion.""" return self.trg_sentence[:pos] + [token] + self.trg_sentence[pos:] def _sub_op(self, pos, token): """Returns a copy of trg sentence after a substitution.""" ret = list(self.trg_sentence) ret[pos] = token return ret def _del_op(self, pos): """Returns a copy of trg sentence after a deletion.""" return self.trg_sentence[:pos] + self.trg_sentence[pos+1:] def _top_n(self, scores, sort=False): """Sorted indices of beam_size best entries along axis 1""" costs = -scores costs[:, utils.EOS_ID] = utils.INF top_n_indices = np.argpartition( costs, self.beam_size, axis=1)[:, :self.beam_size] if not sort: return top_n_indices b_indices = np.expand_dims(np.arange(top_n_indices.shape[0]), axis=1) sorted_indices = np.argsort(costs[b_indices, top_n_indices], axis=1) return top_n_indices[b_indices, sorted_indices] def predict_next(self): """Call the T2T model in self.mon_sess.""" next_sentences = {} logging.debug("EditT2T: Exploring score=%f sentence=%s" % (self.cur_score, " ".join(map(str, self.trg_sentence)))) n_trg_words = len(self.trg_sentence) if n_trg_words > EditT2TPredictor.MAX_SEQ_LEN: logging.warn("EditT2T: Target sentence exceeds maximum length (%d)" % EditT2TPredictor.MAX_SEQ_LEN) return {utils.EOS_ID: 0.0} # Substitutions log_probs = self.mon_sess.run(self._log_probs, {self._inputs_var: self.src_sentence, self._targets_var: [self.trg_sentence]}) top_n = self._top_n(np.squeeze(log_probs, axis=0)) for pos, cur_token in enumerate(self.trg_sentence[:-1]): offset = EditT2TPredictor.SUB_OFFSET offset += EditT2TPredictor.POS_FACTOR * pos for token in top_n[pos]: if token !=
cur_token: next_sentences[offset + token] = self._sub_op(pos, token) # Insertions if n_trg_words < EditT2TPredictor.MAX_SEQ_LEN - 1: ins_trg_sentences = np.full((n_trg_words, n_trg_words+1), 999) for pos in xrange(n_trg_words): ins_trg_sentences[pos, :pos] = self.trg_sentence[:pos] ins_trg_sentences[pos, pos+1:] = self.trg_sentence[pos:] diag_log_probs = self.mon_sess.run(self._diag_log_probs, {self._inputs_var: self.src_sentence, self._targets_var: ins_trg_sentences}) top_n = self._top_n(np.squeeze(diag_log_probs, axis=1)) for pos in xrange(n_trg_words): offset = EditT2TPredictor.INS_OFFSET offset += EditT2TPredictor.POS_FACTOR * pos for token in top_n[pos]: next_sentences[offset + token] = self._ins_op(pos, token) # Deletions idx = EditT2TPredictor.DEL_OFFSET for pos in xrange(n_trg_words - 1): # -1: Do not delete EOS next_sentences[idx] = self._del_op(pos) idx += EditT2TPredictor.POS_FACTOR abs_scores = self._score(next_sentences, n_trg_words + 1) rel_scores = {i: s - self.cur_score for i, s in abs_scores.iteritems()} rel_scores[utils.EOS_ID] = 0.0 return rel_scores def _score(self, sentences, n_trg_words=1): max_n_sens = max(1, self.batch_size // n_trg_words) scores = {} batch_ids = [] batch_sens = [] for idx, trg_sentence in sentences.iteritems(): score = self.cache.get(trg_sentence) if score is None: batch_ids.append(idx) np_sen = np.zeros(n_trg_words, dtype=np.int) np_sen[:len(trg_sentence)] = trg_sentence batch_sens.append(np_sen) if len(batch_ids) >= max_n_sens: self._score_single_batch(scores, batch_ids, batch_sens) batch_ids = [] batch_sens = [] else: scores[idx] = score self._score_single_batch(scores, batch_ids, batch_sens) return scores def _score_single_batch(self, scores, ids, trg_sentences): """Score sentences and add them to scores and the cache.""" if not ids: return batch_scores = self.mon_sess.run(self._sentence_scores, {self._inputs_var: self.src_sentence, self._targets_var: np.stack(trg_sentences)}) for idx, sen, score in zip(ids, trg_sentences, batch_scores): self.cache.add(sen, score) scores[idx] = score def _update_cur_score(self): self.cur_score = self.cache.get(self.trg_sentence) if self.cur_score is None: scores = self._score({1: self.trg_sentence}, len(self.trg_sentence)) self.cur_score = scores[1] self.cache.add(self.trg_sentence, self.cur_score) def initialize(self, src_sentence): """Set src_sentence and the initial target sentence, reset the cache.""" if self.initial_trg_sentences is None: self.trg_sentence = [text_encoder.EOS_ID] else: self.trg_sentence = self.initial_trg_sentences[self.current_sen_id] self.src_sentence = utils.oov_to_unk( src_sentence + [text_encoder.EOS_ID], self.src_vocab_size, self._t2t_unk_id) self.cache = SimpleTrie() self._update_cur_score() logging.debug("Initial score: %f" % self.cur_score) def consume(self, word): """Apply the edit operation encoded by ``word`` to the target sentence.""" if word == utils.EOS_ID: return pos = (word // EditT2TPredictor.POS_FACTOR) \ % (EditT2TPredictor.MAX_SEQ_LEN + 1) token = word % EditT2TPredictor.POS_FACTOR # TODO(fstahlberg): Do not hard code the following section op = word // 100000000 if op == 1: # Insertion self.trg_sentence = self._ins_op(pos, token) elif op == 2: # Substitution self.trg_sentence = self._sub_op(pos, token) elif op == 3: # Deletion self.trg_sentence = self._del_op(pos) else: logging.warn("Invalid edit descriptor %d. Ignoring..."
% word) self._update_cur_score() self.cache.add(self.trg_sentence, utils.NEG_INF) def get_state(self): """The predictor state is the complete target sentence.""" return self.trg_sentence, self.cur_score def set_state(self, state): """The predictor state is the complete target sentence.""" self.trg_sentence, self.cur_score = state def is_equal(self, state1, state2): """Returns true if the target sentence is the same """ return state1[0] == state2[0]
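Since ``consume`` unpacks edit descriptors arithmetically, a short sketch of the encoding may help. The constants mirror the class attributes above; the helper names (``encode_edit``, ``decode_edit``) are illustrative only and not part of SGNMT.

POS_FACTOR = 100000       # EditT2TPredictor.POS_FACTOR
OP_FACTOR = 100000000     # leading digit: 1=insert, 2=substitute, 3=delete

def encode_edit(op, pos, token=0):
    """Pack (operation, position, token) into one SGNMT edit token."""
    return op * OP_FACTOR + pos * POS_FACTOR + token

def decode_edit(word):
    """Unpack a descriptor the same way ``consume`` does."""
    op = word // OP_FACTOR
    pos = (word // POS_FACTOR) % 1000   # MAX_SEQ_LEN + 1
    token = word % POS_FACTOR
    return op, pos, token

# 200300042 reads as "2|003|00042": substitute the token at position 3
# with vocabulary id 42.
assert encode_edit(2, 3, 42) == 200300042
assert decode_edit(200300042) == (2, 3, 42)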