def create_initial_stack(self, model, batch_placeholder, force_bos=True, **flags):
    """
    Build the decoding stack that seeds the greedy while_loop.

    Encodes the input once and records the attention probabilities and tracked
    outputs of the encoded state. If force_bos is set, it also feeds
    model.out_voc.bos as the zero-th output and decodes one step, appending that
    step's attention / tracked outputs as well.

    :param model: something that implements TranslateModel
        (encode, decode, get_attnP, out_voc)
    :param batch_placeholder: dict passed to model.encode; must contain 'inp'
    :param force_bos: if True, the zero-th output is model.out_voc.bos and one
        decode step is performed here; otherwise outputs start empty
    :param flags: forwarded to model.encode / model.decode
    :returns: self.Stack(out_len, out_seq, logp, mask, dec_state, attnP, tracked)
    """
    inp = batch_placeholder['inp']
    batch_size = tf.shape(inp)[0]

    initial_state = model.encode(batch_placeholder, **flags)
    # attnP / tracked outputs get a time axis so later steps can concat onto axis=1
    initial_attnP = model.get_attnP(initial_state)[:, None]
    initial_tracked = nested_map(lambda x: x[:, None], self.get_tracked_outputs(initial_state))

    if force_bos:
        # outputs share inp's integer dtype
        initial_outputs = tf.cast(
            tf.fill((batch_size, 1), model.out_voc.bos), inp.dtype)
        initial_state = model.decode(initial_state, initial_outputs[:, 0], **flags)
        second_attnP = model.get_attnP(initial_state)[:, None]
        initial_attnP = tf.concat([initial_attnP, second_attnP], axis=1)
        initial_tracked = nested_map(
            lambda x, y: tf.concat([x, y[:, None]], axis=1),
            initial_tracked,
            self.get_tracked_outputs(initial_state),
        )
    else:
        initial_outputs = tf.zeros((batch_size, 0), dtype=inp.dtype)

    initial_logp = tf.zeros([batch_size], dtype='float32')
    # One "still decoding" flag per sequence. tf.ones_like([batch_size]) would have
    # produced a shape-[1] tensor (the list holds a single scalar) and only worked
    # through broadcasting in step().
    initial_mask = tf.ones([batch_size], dtype=tf.bool)
    initial_len = tf.shape(initial_outputs)[1]
    return self.Stack(initial_len, initial_outputs, initial_logp, initial_mask,
                      initial_state, initial_attnP, initial_tracked)
def step(self, model, stack, **flags):
    """
    Extend every hypothesis in a greedy-decoding stack by exactly one token.

    :type model: models.TranslateModel
    :param stack: self.Stack(out_len, out_seq, logp, mask, dec_states, attnP, tracked)
    :param flags: forwarded to model.sample / model.decode
    :return: a new self.Stack whose sequences are one token longer
    """
    (cur_len, cur_seq, cur_logp, active,
     cur_states, cur_attnP, cur_tracked) = stack

    # -- pick one token per row (k=1) --
    rows = tf.range(tf.shape(cur_seq)[0])
    _, picked, picked_logp = model.sample(cur_states, cur_logp, rows, k=1, **flags)
    grown_seq = tf.concat([cur_seq, picked], axis=1)

    # The score uses the mask from BEFORE this step's EOS check, so the EOS
    # token's own log-prob is counted; later tokens of finished rows are not.
    grown_logp = cur_logp + picked_logp[:, 0] * tf.cast(active, 'float32')
    hit_eos = tf.equal(picked[:, 0], model.out_voc.eos)
    still_active = tf.logical_and(active, tf.logical_not(hit_eos))

    # -- feed the picked token back into the decoder --
    next_states = model.decode(cur_states, picked[:, 0], **flags)
    grown_attnP = tf.concat([cur_attnP, model.get_attnP(next_states)[:, None]], axis=1)

    def _append_step(history, latest):
        return tf.concat([history, latest[:, None]], axis=1)

    grown_tracked = nested_map(_append_step, cur_tracked,
                               self.get_tracked_outputs(next_states))

    return self.Stack(cur_len + 1, grown_seq, grown_logp, still_active,
                      next_states, grown_attnP, grown_tracked)
def shuffle(self, dec_state, flat_indices):
    """
    Pick hypotheses out of a decoder state by index.

    Every tensor in the (possibly nested) state is gathered along its first
    axis, so hypothesis i of the result is hypothesis flat_indices[i] of the input.

    :param dec_state: nested structure of tensors holding the model state
    :param flat_indices: int32 vector of indices to select
    :returns: a structure like dec_state containing only the selected rows
    """
    def _take_rows(tensor):
        return tf.gather(tensor, flat_indices, axis=0)

    return nested_map(_take_rows, dec_state)
def __init__(self, model, batch_placeholder, min_len=None, max_len=None, beam_size=3, beam_spread=3,
             force_bos=True, if_no_eos='last', back_prop=True, swap_memory=False, **flags):
    """
    Performs in-graph beam search for given input sequences (inp).
    Supports penalizing, pruning against best score and best score in beam (via beam_spread).

    :param model: something that implements TranslateModel
    :param batch_placeholder: whatever model can .encode,
        by default should be {'inp': int32 matrix [batch_size x time]}.
        May also carry 'inp_len'; otherwise it is inferred from 'inp'.
    :param min_len: minimum valid output length. None means min_len = inp_len // 4 - 1
    :param max_len: maximum hypothesis length to consider,
        float('inf') means unlimited, None means max_len = 2 * inp_len + 3
    :param beam_size: maximum number of hypotheses that can pass from one beam search step to another.
        The rest is pruned.
    :param beam_spread: maximum difference in score between a hypothesis and current best hypothesis.
        Anything below that is pruned.
    :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide.
    :param if_no_eos: if 'last', will return unfinished hypos if there are no finished hypos by max_len;
        if 'initial', returns empty hypothesis
    :param back_prop: see tf.while_loop back_prop param
    :param swap_memory: see tf.while_loop swap_memory param
    :param **flags: whatever else you want to feed into model. This will be passed to encode, decode, etc.
        is_train - if True (default), enables dropouts and similar training-only stuff
        sampling_strategy - if "random", samples hypotheses proportionally to softmax(logits)
                            otherwise (default) - takes top K hypotheses
        sampling_temperature - if sampling_strategy == "random",
            performs sampling ~ softmax(logits/sampling_temperature)

    Sets: best_attnP, best_out, best_scores, best_raw_scores, best_state
    (taken from the final stack of the while_loop).
    """
    print("Preparing BEAM SEARCH translate with params:", locals())

    assert if_no_eos in ['last', 'initial']
    # max_len=None passes this check: its default (2 * inp_len + 3) is finite.
    assert np.isfinite(beam_spread) or max_len != float('inf'), "Must set maximum length if beam_spread is infinite"

    # initialize fields
    self.batch_placeholder = batch_placeholder
    # dict.get evaluates its default eagerly, so the infer_length op is built even when
    # 'inp_len' is supplied. NOTE(review): the input length is inferred with
    # model.out_voc.eos - assumes input and output vocabularies share the EOS id; confirm.
    inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos))
    self.min_len = min_len if min_len is not None else inp_len // 4 - 1
    self.max_len = max_len if max_len is not None else 2 * inp_len + 3
    self.beam_size, self.beam_spread = beam_size, beam_spread
    self.force_bos, self.if_no_eos = force_bos, if_no_eos

    # actual beam search
    first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags)
    # every axis is declared dynamic so the loop body may grow any dimension
    shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

    # tf.while_loop passes loop vars positionally; re-wrap them into the stack namedtuple.
    # BeamSearchStack / should_extend_hypo / beam_search_step are defined outside this view.
    def should_continue_translating(*stack):
        stack = self.BeamSearchStack(*stack)
        should_continue = self.should_extend_hypo(model, stack)
        return tf.reduce_any(should_continue)

    def expand_hypos(*stack):
        stack = self.BeamSearchStack(*stack)
        return self.beam_search_step(model, stack, **flags)

    last_stack = tf.while_loop(
        cond=should_continue_translating,
        body=expand_hypos,
        loop_vars=first_stack,
        shape_invariants=shape_invariants,
        back_prop=back_prop,
        swap_memory=swap_memory,
    )

    # crop unnecessary EOSes that occur if no hypothesis is updated on several last steps
    actual_length = infer_length(last_stack.best_out, model.out_voc.eos)
    max_length = tf.reduce_max(actual_length)
    # NOTE(review): only best_out is cropped; best_attnP keeps its full time axis - confirm
    # callers do not expect matching lengths.
    last_stack = last_stack._replace(best_out=last_stack.best_out[:, :max_length])

    self.best_attnP = last_stack.best_attnP
    self.best_out = last_stack.best_out
    self.best_scores = last_stack.best_scores
    self.best_raw_scores = last_stack.best_raw_scores
    self.best_state = last_stack.best_dec_state
def __init__(self, model, batch_placeholder, max_len=None, force_bos=True, force_eos=True,
             get_tracked_outputs=lambda dec_state: [], crop_last_step=True, back_prop=True, swap_memory=False,
             **flags):
    """
    Encode input sequence and iteratively decode output sequence.
    To be used in this fashion:
        trans = GreedyInference(model, {'inp':...}, ...)
        loss = cmon_do_something(trans.best_out, trans.best_scores)
        sess.run(loss)

    :type model: models.TranslateModel
    :param batch_placeholder: a dictionary that contains symbolic tensor
        {'inp': input token ids, shape [batch_size,time]}; may also carry 'inp_len'
    :param max_len: maximum length of output sequence, defaults to 2*inp_len + 3.
        Decoding continues while ANY sequence is shorter than its max_len and ANY
        sequence has not yet produced EOS.
    :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide.
    :param force_eos: if True, any token past initial EOS is guaranteed to be EOS
    :param get_tracked_outputs: callback that returns whatever tensor(s) you want to track on each time-step
    :param crop_last_step: if True, drops the attnP / tracked entry produced by the extra decode
        __after__ the last sampled token, so those tensors share the outputs' time axis
    :param back_prop: see tf.while_loop back_prop param
    :param swap_memory: see tf.while_loop swap_memory param
    :param **flags: you can add any amount of tags that encode and decode understands.
        e.g. greedy=True or is_train=True

    Sets best_out/sample_out, best_attnP/sample_attnP, best_scores/sample_scores,
    best_dec_states/dec_states and tracked_outputs.
    """
    print("Preparing GREEDY translate with params:", locals())

    self.batch_placeholder = batch_placeholder
    # must be set before create_initial_stack / step, which call it
    self.get_tracked_outputs = get_tracked_outputs
    inp_len = batch_placeholder.get(
        'inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos))
    max_len = max_len if max_len is not None else (2 * inp_len + 3)

    first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags)
    # every axis is dynamic so the loop body may grow the time axis
    shape_invariants = nested_map(
        lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

    # Actual decoding; tf.while_loop passes loop vars positionally, so re-wrap them.
    def should_continue_translating(*stack):
        stack = self.Stack(*stack)
        return tf.reduce_any(tf.less(stack.out_len, max_len)) & tf.reduce_any(stack.mask)

    def inference_step(*stack):
        stack = self.Stack(*stack)
        return self.step(model, stack, **flags)

    final_stack = tf.while_loop(
        cond=should_continue_translating,
        body=inference_step,
        loop_vars=first_stack,
        shape_invariants=shape_invariants,
        swap_memory=swap_memory,
        back_prop=back_prop,
    )
    _, outputs, scores, _, dec_states, attnP, tracked_outputs = final_stack

    if crop_last_step:
        attnP = attnP[:, :-1]
        tracked_outputs = nested_map(lambda out: out[:, :-1], tracked_outputs)

    if force_eos:
        out_mask = infer_mask(outputs, model.out_voc.eos)
        # tf.fill with a python int always yields int32; cast so tf.where also accepts
        # outputs of other integer dtypes (outputs inherit inp's dtype, e.g. int64).
        eos_fill = tf.fill(tf.shape(outputs), tf.cast(model.out_voc.eos, outputs.dtype))
        outputs = tf.where(out_mask, outputs, eos_fill)

    self.best_out = self.sample_out = outputs
    self.best_attnP = self.sample_attnP = attnP
    self.best_scores = self.sample_scores = scores
    self.best_dec_states = self.dec_states = dec_states
    self.tracked_outputs = tracked_outputs
def switch(self, condition, state_on_true, state_on_false):
    """
    Merge two nested states of the same structure with tf.where.

    For each pair of matching tensors, entries where condition is True come from
    state_on_true and the rest from state_on_false (e.g. keep the new
    best_dec_state where the new hypothesis is better, the old one otherwise).

    :param condition: a boolean condition vector of shape [batch_size]
    :param state_on_true: nested structure of tensors chosen where condition holds
    :param state_on_false: nested structure of tensors chosen elsewhere
    :returns: nested structure like the inputs

    NOTE(review): with a rank-1 condition and higher-rank tensors this relies on TF1
    tf.where selecting whole rows along the first axis; TF2's tf.where broadcasts
    instead, which would align condition with the LAST axis. Confirm the TF version.
    """
    def _pick(on_true, on_false):
        return tf.where(condition, on_true, on_false)

    return nested_map(_pick, state_on_true, state_on_false)