Example #1
    def create_initial_stack(self,
                             model,
                             batch_placeholder,
                             force_bos=True,
                             **flags):
        inp = batch_placeholder['inp']
        batch_size = tf.shape(inp)[0]

        initial_state = model.encode(batch_placeholder, **flags)
        initial_attnP = model.get_attnP(initial_state)[:, None]
        initial_tracked = nested_map(lambda x: x[:, None],
                                     self.get_tracked_outputs(initial_state))

        if force_bos:
            initial_outputs = tf.cast(
                tf.fill((batch_size, 1), model.out_voc.bos), inp.dtype)
            initial_state = model.decode(initial_state, initial_outputs[:, 0],
                                         **flags)
            second_attnP = model.get_attnP(initial_state)[:, None]
            initial_attnP = tf.concat([initial_attnP, second_attnP], axis=1)
            initial_tracked = nested_map(
                lambda x, y: tf.concat([x, y[:, None]], axis=1),
                initial_tracked,
                self.get_tracked_outputs(initial_state),
            )
        else:
            initial_outputs = tf.zeros((batch_size, 0), dtype=inp.dtype)

        initial_logp = tf.zeros([batch_size], dtype='float32')
        initial_mask = tf.ones([batch_size], dtype='bool')
        initial_len = tf.shape(initial_outputs)[1]

        return self.Stack(initial_len, initial_outputs, initial_logp,
                          initial_mask, initial_state, initial_attnP,
                          initial_tracked)
Example #2
    def step(self, model, stack, **flags):
        """
        :type model: models.TranslateModel
        :param stack: beam search stack
        :return: new beam search stack
        """
        out_len, out_seq, logp, mask, dec_states, attnP, tracked = stack

        # 1. sample
        batch_size = tf.shape(out_seq)[0]
        phony_slices = tf.range(batch_size)
        _, new_outputs, logp_next = model.sample(dec_states,
                                                 logp,
                                                 phony_slices,
                                                 k=1,
                                                 **flags)

        out_seq = tf.concat([out_seq, new_outputs], axis=1)
        logp = logp + logp_next[:, 0] * tf.cast(mask, 'float32')
        is_eos = tf.equal(new_outputs[:, 0], model.out_voc.eos)
        mask = tf.logical_and(mask, tf.logical_not(is_eos))

        # 2. decode
        new_states = model.decode(dec_states, new_outputs[:, 0], **flags)
        attnP = tf.concat([attnP, model.get_attnP(new_states)[:, None]],
                          axis=1)
        tracked = nested_map(
            lambda seq, new: tf.concat([seq, new[:, None]], axis=1), tracked,
            self.get_tracked_outputs(new_states))
        return self.Stack(out_len + 1, out_seq, logp, mask, new_states, attnP,
                          tracked)
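The score and mask bookkeeping in the sampling step can be checked in isolation; a small standalone sketch of the same arithmetic with made-up values:

import tensorflow as tf

logp = tf.constant([-1.2, -0.7])            # running scores of two hypotheses
logp_next = tf.constant([[-0.5], [-0.3]])   # log-prob of the token just sampled (k=1)
mask = tf.constant([True, False])           # the second hypothesis already emitted EOS

# Finished hypotheses keep their score: the mask zeroes out further increments.
logp = logp + logp_next[:, 0] * tf.cast(mask, 'float32')   # -> [-1.7, -0.7]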
Example #3
    def shuffle(self, dec_state, flat_indices):
        """
        Selects hypotheses from model decoder state by given indices.
        :param dec_state: a nested structure of tensors representing model state
        :param flat_indices: int32 vector of indices to select
        :returns: dec_state elements for given flat_indices only
        """
        return nested_map(lambda var: tf.gather(var, flat_indices, axis=0),
                          dec_state)
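For example, assuming the decoder state is a nested dict of tensors with hypotheses on axis 0 (the layout below is illustrative, not the library's actual state), shuffle reorders and possibly duplicates hypotheses:

import tensorflow as tf

dec_state = {'rnn': {'h': tf.zeros([4, 8]), 'c': tf.zeros([4, 8])},
             'step': tf.range(4)}
flat_indices = tf.constant([2, 2, 0], dtype=tf.int32)   # keep hypo 2 twice, then hypo 0
picked = nested_map(lambda var: tf.gather(var, flat_indices, axis=0), dec_state)
# picked['rnn']['h'] has shape [3, 8]; picked['step'] == [2, 2, 0]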
Example #4
    def __init__(self,
                 model,
                 batch_placeholder,
                 min_len=None,
                 max_len=None,
                 beam_size=3,
                 beam_spread=3,
                 force_bos=True,
                 if_no_eos='last',
                 back_prop=True,
                 swap_memory=False,
                 **flags):
        """
        Performs in-graph beam search for the given input sequences (inp).
        Supports penalties and pruning against the best score overall and the best score within the beam (via beam_spread).
        :param model: something that implements TranslateModel
        :param batch_placeholder: whatever model can .encode,
            by default should be {'inp': int32 matrix [batch_size x time]}
        :param min_len: minimum valid output length. None means min_len = inp_len // 4 - 1
        :param max_len: maximum hypothesis length to consider;
            float('inf') means unlimited, None means max_len = 2 * inp_len + 3
        :param beam_size: maximum number of hypotheses that can pass from one beam search step to another.
            The rest is pruned.
        :param beam_spread: maximum difference in score between a hypothesis and current best hypothesis.
            Anything below that is pruned.
        :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide.
        :param if_no_eos: if 'last', returns unfinished hypotheses if there are no finished ones by max_len;
                          if 'initial', returns an empty hypothesis
        :param back_prop: see tf.while_loop back_prop param
        :param swap_memory: see tf.while_loop swap_memory param

        :param **flags: anything else you want to feed into the model; passed on to encode, decode, etc.
            is_train - if True (default), enables dropout and similar training-only behaviour
            sampling_strategy - if "random", samples hypotheses proportionally to softmax(logits);
                                otherwise (default), takes the top-k hypotheses
            sampling_temperature - if sampling_strategy == "random",
                performs sampling ~ softmax(logits / sampling_temperature)

        """
        print("Preparing BEAM SEARCH translate with params:", locals())
        assert if_no_eos in ['last', 'initial']
        assert np.isfinite(beam_spread) or max_len != float(
            'inf'), "Must set maximum length if beam_spread is infinite"
        # initialize fields
        self.batch_placeholder = batch_placeholder
        inp_len = batch_placeholder.get(
            'inp_len', infer_length(batch_placeholder['inp'],
                                    model.out_voc.eos))
        self.min_len = min_len if min_len is not None else inp_len // 4 - 1
        self.max_len = max_len if max_len is not None else 2 * inp_len + 3
        self.beam_size, self.beam_spread = beam_size, beam_spread
        self.force_bos, self.if_no_eos = force_bos, if_no_eos

        # actual beam search
        first_stack = self.create_initial_stack(model,
                                                batch_placeholder,
                                                force_bos=force_bos,
                                                **flags)
        shape_invariants = nested_map(
            lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

        def should_continue_translating(*stack):
            stack = self.BeamSearchStack(*stack)
            should_continue = self.should_extend_hypo(model, stack)
            return tf.reduce_any(should_continue)

        def expand_hypos(*stack):
            stack = self.BeamSearchStack(*stack)
            return self.beam_search_step(model, stack, **flags)

        last_stack = tf.while_loop(
            cond=should_continue_translating,
            body=expand_hypos,
            loop_vars=first_stack,
            shape_invariants=shape_invariants,
            back_prop=back_prop,
            swap_memory=swap_memory,
        )

        # crop unnecessary EOS tokens that occur when no hypothesis is updated during the last several steps
        actual_length = infer_length(last_stack.best_out, model.out_voc.eos)
        max_length = tf.reduce_max(actual_length)
        last_stack = last_stack._replace(
            best_out=last_stack.best_out[:, :max_length])

        self.best_attnP = last_stack.best_attnP
        self.best_out = last_stack.best_out
        self.best_scores = last_stack.best_scores
        self.best_raw_scores = last_stack.best_raw_scores
        self.best_state = last_stack.best_dec_state
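Typical usage mirrors the GreedyInference docstring in the next example. A hedged sketch, where the class name BeamSearchInference, the model object, sess and batch_token_ids are assumptions rather than part of the snippet above:

import tensorflow as tf

inp = tf.placeholder(tf.int32, [None, None], name='inp')      # TF1-style graph mode
beam = BeamSearchInference(model, {'inp': inp},                # hypothetical class name
                           beam_size=8, beam_spread=3, force_bos=True)
best_tokens, best_scores = sess.run([beam.best_out, beam.best_scores],
                                    feed_dict={inp: batch_token_ids})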
Example #5
    def __init__(self,
                 model,
                 batch_placeholder,
                 max_len=None,
                 force_bos=True,
                 force_eos=True,
                 get_tracked_outputs=lambda dec_state: [],
                 crop_last_step=True,
                 back_prop=True,
                 swap_memory=False,
                 **flags):
        """
        Encode input sequence and iteratively decode output sequence.
        To be used in this fashion:
        trans = GreedyInference(model, {'inp':...}, ...)
        loss = cmon_do_something(trans.best_out, trans.best_scores)
        sess.run(loss)

        :type model: models.TranslateModel
        :param batch_placeholder: a dictionary containing the symbolic tensor {'inp': input token ids, shape [batch_size, time]}
        :param max_len: maximum length of output sequence, defaults to 2*inp_len + 3
        :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide.
        :param force_eos: if True, any token past initial EOS is guaranteed to be EOS
        :param get_tracked_outputs: callback that returns whatever tensor(s) you want to track on each time-step
        :param crop_last_step: if True, does not perform an additional decode step after the last EOS;
                ensures all tensors have an equal time axis
        :param back_prop: see tf.while_loop back_prop param
        :param swap_memory: see tf.while_loop swap_memory param
        :param **flags: any number of additional flags that encode and decode understand,
            e.g. greedy=True or is_train=True

        """

        print("Preparing GREEDY translate with params:", locals())
        self.batch_placeholder = batch_placeholder
        self.get_tracked_outputs = get_tracked_outputs

        inp_len = batch_placeholder.get(
            'inp_len', infer_length(batch_placeholder['inp'],
                                    model.out_voc.eos))
        max_len = max_len if max_len is not None else (2 * inp_len + 3)

        first_stack = self.create_initial_stack(model,
                                                batch_placeholder,
                                                force_bos=force_bos,
                                                **flags)
        shape_invariants = nested_map(
            lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

        # Actual decoding
        def should_continue_translating(*stack):
            stack = self.Stack(*stack)
            return tf.reduce_any(tf.less(stack.out_len,
                                         max_len)) & tf.reduce_any(stack.mask)

        def inference_step(*stack):
            stack = self.Stack(*stack)
            return self.step(model, stack, **flags)

        final_stack = tf.while_loop(
            cond=should_continue_translating,
            body=inference_step,
            loop_vars=first_stack,
            shape_invariants=shape_invariants,
            swap_memory=swap_memory,
            back_prop=back_prop,
        )

        _, outputs, scores, _, dec_states, attnP, tracked_outputs = final_stack
        if crop_last_step:
            attnP = attnP[:, :-1]
            tracked_outputs = nested_map(lambda out: out[:, :-1],
                                         tracked_outputs)

        if force_eos:
            out_mask = infer_mask(outputs, model.out_voc.eos)
            outputs = tf.where(out_mask, outputs,
                               tf.fill(tf.shape(outputs), model.out_voc.eos))

        self.best_out = self.sample_out = outputs
        self.best_attnP = self.sample_attnP = attnP
        self.best_scores = self.sample_scores = scores
        self.best_dec_states = self.dec_states = dec_states
        self.tracked_outputs = tracked_outputs
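Both constructors lean on infer_length and infer_mask helpers from the surrounding library. A possible sketch of their semantics, assuming the first EOS terminates a hypothesis (the real implementations may differ):

import tensorflow as tf

def infer_length(seq, eos):
    """Number of tokens up to and including the first EOS; full length if no EOS (sketch)."""
    is_eos = tf.cast(tf.equal(seq, eos), tf.int32)
    eos_before = tf.cumsum(is_eos, axis=1, exclusive=True)   # EOS count strictly before position t
    return tf.reduce_sum(tf.cast(tf.equal(eos_before, 0), tf.int32), axis=1)

def infer_mask(seq, eos, dtype=tf.bool):
    """Boolean mask that is True up to and including the first EOS (sketch)."""
    lengths = infer_length(seq, eos)
    return tf.cast(tf.sequence_mask(lengths, maxlen=tf.shape(seq)[1]), dtype)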
Example #6
    def switch(self, condition, state_on_true, state_on_false):
        """
        Composes a new decoder state (e.g. stack.best_dec_state), taking each element
        from state_on_true where condition holds and from state_on_false otherwise.
        :param condition: a boolean condition vector of shape [batch_size]
        """
        return nested_map(lambda x, y: tf.where(condition, x, y),
                          state_on_true, state_on_false)
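A standalone illustration of the element-wise selection, using rank-1 leaves so the same lines run in both graph and eager mode (in the actual stack the leaves are typically matrices, and the rank-1 condition selects whole rows, as TF1's tf.where allows):

import tensorflow as tf

condition = tf.constant([True, False, True])                    # per-hypothesis choice
state_a = {'logp': tf.constant([0.0, -1.0, -2.0]), 'len': tf.constant([3, 4, 5])}
state_b = {'logp': tf.constant([-9.0, -0.5, -9.0]), 'len': tf.constant([2, 2, 2])}
merged = nested_map(lambda x, y: tf.where(condition, x, y), state_a, state_b)
# merged['logp'] == [0.0, -0.5, -2.0]; merged['len'] == [3, 2, 5]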