def create_initial_stack(self, model, batch_placeholder, force_bos=False, **flags):
        inp = batch_placeholder['inp']
        batch_size = tf.shape(inp)[0]

        initial_state = model.encode(batch_placeholder, **flags)
        initial_attnP = model.get_attnP(initial_state)[:, None]
        initial_tracked = nested_map(lambda x: x[:, None], self.get_tracked_outputs(initial_state))

        if force_bos:
            initial_outputs = tf.cast(tf.fill((batch_size, 1), model.out_voc.bos), inp.dtype)
            initial_state = model.decode(initial_state, initial_outputs[:, 0], **flags)
            second_attnP = model.get_attnP(initial_state)[:, None]
            initial_attnP = tf.concat([initial_attnP, second_attnP], axis=1)
            initial_tracked = nested_map(lambda x, y: tf.concat([x, y[:, None]], axis=1),
                                         initial_tracked,
                                         self.get_tracked_outputs(initial_state),)
        else:
            initial_outputs = tf.zeros((batch_size, 0), dtype=inp.dtype)

        initial_scores = tf.zeros([batch_size], dtype='float32')
        initial_finished = tf.zeros([batch_size], dtype='bool')  # one flag per sample in the batch
        initial_len = tf.shape(initial_outputs)[1]

        return self.Stack(initial_outputs, initial_len, initial_scores, initial_finished,
                          initial_state, initial_attnP, initial_tracked)
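
    # A shape sketch of the stack built above (inferred from this fragment; the namedtuple
    # definition itself is not shown). Fields, in order:
    #   outputs  -- int matrix [batch_size, n_steps] of tokens produced so far
    #   length   -- scalar int, number of decoding steps taken
    #   scores   -- float32 vector [batch_size] of hypothesis log-probabilities
    #   finished -- bool vector [batch_size], True once EOS has been emitted
    #   state / attnP / tracked -- decoder state, attention maps [batch_size, n_steps, ...]
    #                              and whatever get_tracked_outputs returns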
    def shuffle_beam(self, model, stack, flat_indices):
        """
        Selects hypotheses by index from entire BeamSearchStack
        Note: this method assumes that both stack and flat_indices are sorted by sample index
        (i.e. first come indices for input0, then indices for input1, then input2, ..., then input[batch_size-1])
        """
        n_hypos = tf.shape(stack.out)[0]
        batch_size = tf.shape(stack.best_out)[0]

        # compute new slices:
        # step 1: get the index of the input sequence (in batch) for each hypothesis in flat_indices
        sample_ids_for_slices = tf.gather(hypo_to_batch_index(n_hypos, stack.slices), flat_indices)
        # step 2: count how many hypotheses are selected for each sample
        n_hypos_per_sample = tf.bincount(sample_ids_for_slices, minlength=batch_size, maxlength=batch_size)
        # step 3: infer slice start indices
        new_slices = tf.cumsum(n_hypos_per_sample, exclusive=True)

        # shuffle everything else
        return stack._replace(
            out=tf.gather(stack.out, flat_indices),
            scores=tf.gather(stack.scores, flat_indices),
            raw_scores=tf.gather(stack.raw_scores, flat_indices),
            attnP=tf.gather(stack.attnP, flat_indices),
            dec_state=model.shuffle(stack.dec_state, flat_indices),
            ext=nested_map(lambda x: tf.gather(x, flat_indices), stack.ext),
            slices=new_slices,
        )
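
    # Note on hypo_to_batch_index (defined elsewhere in the library, not in this fragment): given the
    # total number of hypotheses and the per-sample slice start offsets, it returns the sample index
    # that each hypothesis row belongs to. For example (illustrative values), with n_hypos=6 and
    # slices=[0, 2, 5] (i.e. 2, 3 and 1 hypotheses per sample) it yields [0, 0, 1, 1, 1, 2].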
    def greedy_step(self, model, stack, **flags):
        """
        :type model: lib.task.seq2seq.inference.translate_model.TranslateModel
        :param stack: beam search stack
        :return: new beam search stack
        """
        out_seq, out_len, scores, finished, dec_states, attnP, tracked = stack

        # 1. sample
        batch_size = tf.shape(out_seq)[0]
        phony_slices = tf.range(batch_size)
        _, new_outputs, logp_next = model.sample(dec_states, scores, phony_slices, k=1, **flags)

        out_seq = tf.concat([out_seq, new_outputs], axis=1)
        scores = scores + logp_next[:, 0] * tf.cast(~finished, 'float32')
        is_eos = tf.equal(new_outputs[:, 0], model.out_voc.eos)
        finished = tf.logical_or(finished, is_eos)

        # 2. decode
        new_states = model.decode(dec_states, new_outputs[:, 0], **flags)
        attnP = tf.concat([attnP, model.get_attnP(new_states)[:, None]], axis=1)
        tracked = nested_map(lambda seq, new: tf.concat([seq, new[:, None]], axis=1),
                             tracked, self.get_tracked_outputs(new_states)
                             )
        return self.Stack(out_seq, out_len + 1, scores, finished, new_states, attnP, tracked)
    def shuffle(self, dec_state, hypo_indices):
        """
        Selects hypotheses from model decoder state by given indices.
        :param dec_state: a nested structure of tensors representing model state
        :param hypo_indices: int32 vector of indices to select
        :returns: dec_state elements for the given hypo_indices only
        """
        return nested_map(lambda x: tf.gather(x, hypo_indices), dec_state)
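
# `nested_map` is imported from the library's utilities and is not shown in this fragment.
# The sketch below only illustrates its contract (the real implementation may differ):
# apply `fn` to corresponding leaves of one or more nested dict/list/tuple/namedtuple
# structures and rebuild the same structure around the results.
def nested_map_sketch(fn, *structures):
    head = structures[0]
    if isinstance(head, dict):
        return {key: nested_map_sketch(fn, *(s[key] for s in structures)) for key in head}
    if isinstance(head, (list, tuple)):
        mapped = [nested_map_sketch(fn, *parts) for parts in zip(*structures)]
        # namedtuples are rebuilt from positional args, plain lists/tuples from an iterable
        return type(head)(*mapped) if hasattr(head, '_fields') else type(head)(mapped)
    return fn(*structures)  # leaf: e.g. a single tensor
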
    def __init__(self, model, batch_placeholder, max_len=None, force_bos=False, force_eos=True,
                 get_tracked_outputs=lambda dec_state: [], crop_last_step=True,
                 back_prop=True, swap_memory=False, **flags):
        self.batch_placeholder = batch_placeholder
        self.get_tracked_outputs = get_tracked_outputs

        inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos))
        max_len = max_len if max_len is not None else (2 * inp_len + 3)

        first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags)
        shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

        # Actual decoding
        def should_continue_translating(*stack):
            stack = self.Stack(*stack)
            return tf.reduce_any(tf.less(stack.out_len, max_len)) & tf.reduce_any(~stack.finished)

        def inference_step(*stack):
            stack = self.Stack(*stack)
            return self.greedy_step(model, stack, **flags)

        final_stack = tf.while_loop(
            cond=should_continue_translating,
            body=inference_step,
            loop_vars=first_stack,
            shape_invariants=shape_invariants,
            swap_memory=swap_memory,
            back_prop=back_prop,
        )

        outputs, _, scores, _, dec_states, attnP, tracked_outputs = final_stack
        if crop_last_step:
            attnP = attnP[:, :-1]
            tracked_outputs = nested_map(lambda out: out[:, :-1], tracked_outputs)

        if force_eos:
            out_mask = infer_mask(outputs, model.out_voc.eos)
            outputs = tf.where(out_mask, outputs, tf.fill(tf.shape(outputs), model.out_voc.eos))

        self.best_out = outputs
        self.best_attnP = attnP
        self.best_scores = scores
        self.dec_states = dec_states
        self.tracked_outputs = tracked_outputs
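
        # Usage sketch (assumed; the enclosing class name, session and feed_dict construction
        # are not part of this fragment):
        #   decoder = GreedyDecoderClass(model, batch_placeholder)   # runs this __init__
        #   tokens, scores = sess.run([decoder.best_out, decoder.best_scores], feed_dict)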
    def __init__(self, model, batch_placeholder, min_len=None, max_len=None,
                 beam_size=12, beam_spread=3, beam_spread_raw=None, force_bos=False,
                 if_no_eos='last', back_prop=True, swap_memory=False, **flags
                 ):
        assert if_no_eos in ['last', 'initial']
        assert np.isfinite(beam_spread) or max_len != float('inf'), "Must set maximum length if beam_spread is infinite"
        # initialize fields
        self.batch_placeholder = batch_placeholder
        inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos))
        self.min_len = min_len if min_len is not None else inp_len // 4 - 1
        self.max_len = max_len if max_len is not None else 2 * inp_len + 3
        self.beam_size, self.beam_spread = beam_size, beam_spread
        if beam_spread_raw is None:
            self.beam_spread_raw = beam_spread
        else:
            self.beam_spread_raw = beam_spread_raw
        self.force_bos, self.if_no_eos = force_bos, if_no_eos

        # actual beam search
        first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags)
        shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

        def should_continue_translating(*stack):
            stack = self.Stack(*stack)
            should_continue = self.should_continue_translating(model, stack)
            return tf.reduce_any(should_continue)

        def expand_hypos(*stack):
            return self.beam_search_step(model, self.Stack(*stack), **flags)

        last_stack = tf.while_loop(
            cond=should_continue_translating,
            body=expand_hypos,
            loop_vars=first_stack,
            shape_invariants=shape_invariants,
            back_prop=back_prop,
            swap_memory=swap_memory,
        )

        # crop unnecessary EOSes that occur if no hypothesis is updated on several last steps
        actual_length = infer_length(last_stack.best_out, model.out_voc.eos)
        max_length = tf.reduce_max(actual_length)
        last_stack = last_stack._replace(best_out=last_stack.best_out[:, :max_length])

        self.best_out = last_stack.best_out
        self.best_attnP = last_stack.best_attnP
        self.best_scores = last_stack.best_scores
        self.best_raw_scores = last_stack.best_raw_scores
        self.best_state = last_stack.best_dec_state
    def switch(self, condition, state_on_true, state_on_false):
        """
        Composes a new stack.best_dec_state: takes elements from state_on_true where condition holds
        (e.g. where the new hypothesis is better) and from state_on_false otherwise
        :param condition: a boolean condition vector of shape [batch_size]
        """
        return nested_map(lambda x, y: tf.where(condition, x, y), state_on_true, state_on_false)
    def __init__(self, model, is_train=False):
        """
        An object that finds most likely sequence of inserts
        :type model: lib.models.Transformer
        """

        self.model = model
        self.batch_ph = make_batch_placeholder(
            model.make_feed_dict(model._get_batch_sample()))

        self.k_best = tf.placeholder('int32', [])
        self.hypo_base_logprobs = tf.placeholder('float32',
                                                 [None])  # [batch_size]

        # a mask of allowed tokens
        not_special = np.array([(i not in model.out_voc._default_token_ix)
                                for i in range(len(model.out_voc))],
                               dtype=bool)  # np.bool is deprecated; plain bool works across numpy versions
        self.allowed_tokens = tf.placeholder_with_default(
            not_special,
            shape=[len(model.out_voc)],
        )
        # ^-- [voc_size]

        # step 1: precompute encoder outputs for all unique input lines
        enc = model.encode(self.batch_ph, is_train)
        self.cached_enc_state = {}
        for key, value in enc.items():
            self.cached_enc_state[key] = tf.Variable(tf.zeros([], value.dtype),
                                                     validate_shape=False,
                                                     trainable=False,
                                                     name=value.name[:-2] +
                                                     '_cached')
            self.cached_enc_state[key].set_shape(value.shape)

        self.compute_enc_state = list(
            nested_flatten(
                nested_map(
                    lambda var, val: tf.assign(var, val, validate_shape=False),
                    self.cached_enc_state, enc)))

        # step 2: assemble decoder outputs for each input
        # there may be several out hypos for the same inp line. Up to beam_size hypos to be exact.
        out_to_inp_ix = tf.zeros([tf.shape(self.batch_ph['out'])[0]],
                                 dtype=tf.int64)
        enc_reordered = {
            k: tf.gather(v, out_to_inp_ix)
            for k, v in self.cached_enc_state.items()
        }

        # step 3: compute logits and action log-probs for inserting tokens and finishing
        logp = model.compute_action_logprobs(self.batch_ph,
                                             is_train,
                                             enc=enc_reordered)

        ###################
        # insert operation
        hypo_logprobs_insert = self.hypo_base_logprobs[:, None,
                                                       None] + logp['insert']
        # ^-- [batch, position, token]
        hypo_logprobs_insert -= 1e9 * tf.to_float(
            tf.logical_not(self.allowed_tokens))
        best_inserts_flat = tf.nn.top_k(tf.reshape(hypo_logprobs_insert, [-1]),
                                        k=self.k_best,
                                        sorted=True)

        out_shape = tf.shape(self.batch_ph['out'])
        batch_size, max_len = out_shape[0], out_shape[1]
        voc_size = len(model.out_voc)

        best_hypo_ix = best_inserts_flat.indices // (max_len * voc_size)
        best_insert_pos = (best_inserts_flat.indices // voc_size) % max_len
        best_token_ix = best_inserts_flat.indices % voc_size
        best_insert_logp = best_inserts_flat.values
        self.insert_kbest = [
            best_hypo_ix, best_insert_pos, best_token_ix, best_insert_logp
        ]
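
        # Example of the flat-index arithmetic above: with max_len=5 and voc_size=100
        # (illustrative numbers), flat index 731 decodes to hypo 731 // 500 = 1,
        # position (731 // 100) % 5 = 2 and token 731 % 100 = 31.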

        ##################
        # eos operation
        self.finished_hypo_logprobs = self.hypo_base_logprobs + logp['finish']
    def __init__(self,
                 model,
                 sess=None,
                 optimized_variables=None,
                 name=None,
                 verbose=False,
                 is_train=True,
                 initialize=True,
                 sampler_opts=None,
                 optimizer_opts=None,
                 grad_clip=0,
                 **kwargs):
        """
        An imperative trainer is an object that performs training on batches. Works out-of-graph (in python).
        It is hard-coded to do one thing - sample-based training - but it does that thing well.
        :type model: lib.models.Transformer
        :param sess: tf session to use. tf.get_default_session by default, create new if no default.
        """
        self.model = model
        self.name = name = name or 'trainer_' + model.name
        self.sess = sess = sess or tf.get_default_session() or tf.InteractiveSession()
        self.verbose = verbose

        with tf.name_scope(self.name), tf.variable_scope(self.name) as scope:
            optimized_variables = optimized_variables or get_optimized_variables(
                model, verbose)
            self.optimized_variables = optimized_variables
            self.step = tf.train.get_or_create_global_step(sess.graph)

            # gradient accumulators (for virtual batch training)
            self.accumulated_grads = [
                tf.Variable(tf.zeros_like(w),
                            trainable=False,
                            name=w.name[:-2] + '_acc')
                for w in optimized_variables
            ]
            self.accumulated_num_batches = tf.Variable(
                tf.zeros(()), trainable=False, name='num_batches_since_update')

            ############
            # step 1: precompute encoder state for all unique input lines
            self.encoder_batch_ph = self.model.make_encoder_batch_ph()

            enc = model.encode(self.encoder_batch_ph, is_train)
            self.cached_enc_state, self.compute_enc_state = make_symbolic_cache(
                enc)
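
            # make_symbolic_cache is not shown in this fragment; judging by its outputs it presumably
            # returns non-trainable Variables holding the encoder state plus assign ops
            # (compute_enc_state) that refresh them for a new batch, mirroring the manual
            # tf.Variable / tf.assign caching written out earlier in this file.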

            ############
            # step 2: path_sampler samples a batch of trajectories (sequences of inserts)
            # it also caches encoder state for efficiency
            self.path_sampler = SampleReferenceInserts(
                model, **(sampler_opts or {}), enc_state=self.cached_enc_state)
            self.cached_enc_state = nested_map(tf.stop_gradient,
                                               self.cached_enc_state)
            self.cached_grad_wrt_enc = nested_map(
                lambda v: tf.Variable(tf.zeros([]),
                                      validate_shape=False,
                                      trainable=False,
                                      name=v.name[:-2] + '_cached_grad'),
                self.cached_enc_state)

            self.reset_cached_grad_wrt_enc = nested_map(
                lambda acc, tensor: tf.assign(
                    acc, tf.zeros_like(tensor), validate_shape=False),
                self.cached_grad_wrt_enc, self.cached_enc_state)
            self.fetch_before_batch = tf.group(
                [self.reset_cached_grad_wrt_enc])
            ############
            # step 3: a trajectory is split into slices (for memory efficiency),
            # for each slice we compute dL/d_w_dec and dL/d_enc_state
            self.slice_ph = {
                'out': tf.placeholder('int32', [None, None]),
                'out_len': tf.placeholder('int32', [None]),
                'out_to_inp_indices': tf.placeholder('int32', [None]),
                'ref_len': tf.placeholder('int32', [None]),
                'ref_inserts': tf.placeholder('int64', [None, 3]),
                'chosen_inserts': tf.placeholder('int64', [None, 3]),
            }
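            # ref_inserts / chosen_inserts are presumably [n_inserts, 3] triples of
            # (row index into out, insert position, token id), matching the
            # (best_hypo_ix, best_insert_pos, best_token_ix) triples produced by the insert scorer above.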
            loss_on_slice, counters_on_slice = self.get_loss_and_counters(
                self.slice_ph,
                self.cached_enc_state,
                is_train=is_train,
                **kwargs)

            flat_enc_keys = sorted(self.cached_enc_state.keys())
            flat_enc_cache = list(self.cached_enc_state[k]
                                  for k in flat_enc_keys)
            flat_accumulated_grad_wrt_enc = [
                self.cached_grad_wrt_enc[k] for k in flat_enc_keys
            ]

            loss_grads_on_slice = tf.gradients(
                loss_on_slice, optimized_variables + flat_enc_cache)
            weight_and_enc_grad_accumulators = self.accumulated_grads + flat_accumulated_grad_wrt_enc
            self.update_grads_on_slice = [
                tf.assign_add(grad_acc, grad) for grad_acc, grad in zip(
                    weight_and_enc_grad_accumulators, loss_grads_on_slice)
                if grad is not None
            ]
            # ^-- sess.run-ning this will update gradients w.r.t. decoder weights and encoder state

            # accumulators for metrics
            self.accumulated_counters = nested_map(
                lambda v: tf.Variable(tf.zeros(v.shape, v.dtype),
                                      trainable=False), counters_on_slice)
            self.update_counters_on_slice = nested_map(
                tf.assign_add, self.accumulated_counters, counters_on_slice)
            self.fetch_on_slice = tf.group(
                [self.update_grads_on_slice, self.update_counters_on_slice])

            ############
            # step 4: once we're finished with all slices in one batch, it's time we compute the remaining gradients
            # dL/d_w_enc = dL/d_enc_state * d_enc_state/d_w_enc

            encoder_state = model.encode(self.encoder_batch_ph,
                                         is_train=is_train)
            flat_encoder_state = [encoder_state[k] for k in flat_enc_keys]
            loss_grads_after_slice = tf.gradients(
                flat_encoder_state,
                optimized_variables,
                grad_ys=flat_accumulated_grad_wrt_enc)
            self.update_grads_after_batch = [
                tf.assign_add(grad_acc, grad) for grad_acc, grad in zip(
                    self.accumulated_grads, loss_grads_after_slice)
                if grad is not None
            ]

            self.fetch_after_batch = tf.group([
                self.update_grads_after_batch,
                tf.assign_add(self.accumulated_num_batches, 1)
            ])

            ############
            # step 5: after one or several batches, we use the accumulated gradients to perform optimization step,
            # compute metrics for summary and then reset all buffers

            with tf.control_dependencies([
                    tf.assert_positive(
                        self.accumulated_num_batches,
                        message='Accumulate gradients over at least one '
                        'full batch before averaging them')
            ]):
                loss_denominator = self.get_denominator(
                    self.accumulated_counters)
                self.grads_avg = [
                    grad_acc / loss_denominator
                    for grad_acc in self.accumulated_grads
                ]

            self.opt = self.get_optimizer(self.step, **(optimizer_opts or {}))

            if grad_clip:
                grads, self.grads_global_norm = tf.clip_by_global_norm(
                    self.grads_avg, grad_clip)
            else:
                grads, self.grads_global_norm = self.grads_avg, tf.global_norm(
                    self.grads_avg)

            self.apply_gradients = tf.group(
                self.opt.apply_gradients(zip(grads, optimized_variables),
                                         global_step=self.step))
            self.reset_gradients = tf.group(
                tf.variables_initializer(self.accumulated_grads +
                                         [self.accumulated_num_batches]))

            self.compute_metrics = self.aggregate_metrics_from_counters(
                self.accumulated_counters)
            self.reset_counters = tf.variables_initializer(
                list(nested_flatten(self.accumulated_counters)))

            if initialize:
                sess.run([
                    self.reset_gradients, self.reset_counters,
                    tf.assign(self.step, 1)
                ])
                remaining_utility_variables = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope.name)
                initialize_uninitialized_variables(
                    sess=sess, var_list=remaining_utility_variables)
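
        # Per-batch usage sketch (assumed; feed-dict construction and trajectory slicing depend on
        # code that is not part of this fragment):
        #   sess.run(trainer.fetch_before_batch)                     # reset dL/d_enc_state accumulators
        #   sess.run(trainer.compute_enc_state, encoder_feed)        # step 1: cache encoder state
        #   for slice_feed in trajectory_slices:                     # steps 2-3
        #       sess.run(trainer.fetch_on_slice, slice_feed)         #   accumulate decoder grads + counters
        #   sess.run(trainer.fetch_after_batch, encoder_feed)        # step 4: push cached grads into encoder
        #   sess.run([trainer.apply_gradients, trainer.compute_metrics])   # step 5: update weights, read metrics
        #   sess.run([trainer.reset_gradients, trainer.reset_counters])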