def create_initial_stack(self, model, batch_placeholder, force_bos=False, **flags):
    inp = batch_placeholder['inp']
    batch_size = tf.shape(inp)[0]

    initial_state = model.encode(batch_placeholder, **flags)
    initial_attnP = model.get_attnP(initial_state)[:, None]
    initial_tracked = nested_map(lambda x: x[:, None], self.get_tracked_outputs(initial_state))

    if force_bos:
        initial_outputs = tf.cast(tf.fill((batch_size, 1), model.out_voc.bos), inp.dtype)
        initial_state = model.decode(initial_state, initial_outputs[:, 0], **flags)
        second_attnP = model.get_attnP(initial_state)[:, None]
        initial_attnP = tf.concat([initial_attnP, second_attnP], axis=1)
        initial_tracked = nested_map(lambda x, y: tf.concat([x, y[:, None]], axis=1),
                                     initial_tracked,
                                     self.get_tracked_outputs(initial_state))
    else:
        initial_outputs = tf.zeros((batch_size, 0), dtype=inp.dtype)

    initial_scores = tf.zeros([batch_size], dtype='float32')
    # note: finished must be a per-sample vector of shape [batch_size];
    # tf.zeros_like([batch_size]) would only produce a shape-[1] tensor
    initial_finished = tf.zeros([batch_size], dtype='bool')
    initial_len = tf.shape(initial_outputs)[1]

    return self.Stack(initial_outputs, initial_len, initial_scores, initial_finished,
                      initial_state, initial_attnP, initial_tracked)
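# The Stack namedtuple itself is defined elsewhere in the class. A minimal sketch of the
# assumed field layout for the greedy decoder, inferred from how greedy_step unpacks it
# (only out_len and finished are confirmed by later code; the other names are assumptions):
from collections import namedtuple

GreedyStackSketch = namedtuple('GreedyStackSketch', [
    'out', 'out_len', 'scores', 'finished', 'dec_state', 'attnP', 'tracked'])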
def shuffle_beam(self, model, stack, flat_indices):
    """
    Selects hypotheses by index from the entire BeamSearchStack.
    Note: this method assumes that both stack and flat_indices are sorted by sample index
    (i.e. first come the indices for input0, then the indices for input1, then input2, ...,
    then input[batch_size - 1])
    """
    n_hypos = tf.shape(stack.out)[0]
    batch_size = tf.shape(stack.best_out)[0]

    # compute new slices:
    # step 1: get the index of the input sequence (in batch) for each hypothesis in flat_indices
    sample_ids_for_slices = tf.gather(hypo_to_batch_index(n_hypos, stack.slices), flat_indices)
    # step 2: compute how many hypos are kept per sample
    n_hypos_per_sample = tf.bincount(sample_ids_for_slices, minlength=batch_size, maxlength=batch_size)
    # step 3: infer slice start indices
    new_slices = tf.cumsum(n_hypos_per_sample, exclusive=True)

    # shuffle everything else
    return stack._replace(
        out=tf.gather(stack.out, flat_indices),
        scores=tf.gather(stack.scores, flat_indices),
        raw_scores=tf.gather(stack.raw_scores, flat_indices),
        attnP=tf.gather(stack.attnP, flat_indices),
        dec_state=model.shuffle(stack.dec_state, flat_indices),
        ext=nested_map(lambda x: tf.gather(x, flat_indices), stack.ext),
        slices=new_slices,
    )
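# A small NumPy worked example of the slice bookkeeping above (values are illustrative only).
# With 5 hypotheses split as slices = [0, 3] (input0 owns hypos 0..2, input1 owns hypos 3..4),
# keeping flat_indices = [0, 2, 3] leaves 2 hypos for input0 and 1 for input1:
import numpy as np

_sample_ids = np.array([0, 0, 1])                        # step 1: batch index of each kept hypo
_n_per_sample = np.bincount(_sample_ids, minlength=2)    # step 2: -> [2, 1]
_new_slices = np.cumsum(_n_per_sample) - _n_per_sample   # step 3: exclusive cumsum -> [0, 2]
assert _new_slices.tolist() == [0, 2]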
def greedy_step(self, model, stack, **flags):
    """
    :type model: lib.task.seq2seq.inference.translate_model.TranslateModel
    :param stack: beam search stack
    :return: new beam search stack
    """
    out_seq, out_len, scores, finished, dec_states, attnP, tracked = stack

    # 1. sample
    batch_size = tf.shape(out_seq)[0]
    phony_slices = tf.range(batch_size)
    _, new_outputs, logp_next = model.sample(dec_states, scores, phony_slices, k=1, **flags)

    out_seq = tf.concat([out_seq, new_outputs], axis=1)
    scores = scores + logp_next[:, 0] * tf.cast(~finished, 'float32')
    is_eos = tf.equal(new_outputs[:, 0], model.out_voc.eos)
    finished = tf.logical_or(finished, is_eos)

    # 2. decode
    new_states = model.decode(dec_states, new_outputs[:, 0], **flags)
    attnP = tf.concat([attnP, model.get_attnP(new_states)[:, None]], axis=1)
    tracked = nested_map(lambda seq, new: tf.concat([seq, new[:, None]], axis=1),
                         tracked, self.get_tracked_outputs(new_states))

    return self.Stack(out_seq, out_len + 1, scores, finished, new_states, attnP, tracked)
def shuffle(self, dec_state, hypo_indices):
    """
    Selects hypotheses from model decoder state by given indices.
    :param dec_state: a nested structure of tensors representing model state
    :param hypo_indices: int32 vector of indices to select
    :returns: dec_state elements for the given hypo_indices only
    """
    return nested_map(lambda x: tf.gather(x, hypo_indices), dec_state)
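# nested_map is used throughout this file but defined elsewhere in the library. A minimal
# sketch of the assumed semantics (apply fn to matching leaves of parallel nested
# dict / list / tuple / namedtuple structures), for reference only:
def nested_map_sketch(fn, *structures):
    head = structures[0]
    if isinstance(head, dict):
        return {key: nested_map_sketch(fn, *(s[key] for s in structures)) for key in head}
    if isinstance(head, (list, tuple)):
        mapped = [nested_map_sketch(fn, *leaves) for leaves in zip(*structures)]
        return type(head)(*mapped) if hasattr(head, '_fields') else type(head)(mapped)
    return fn(*structures)

# e.g. nested_map_sketch(lambda x: x + 1, {'a': 1, 'b': (2, 3)}) -> {'a': 2, 'b': (3, 4)}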
def __init__(self, model, batch_placeholder, max_len=None, force_bos=False, force_eos=True,
             get_tracked_outputs=lambda dec_state: [], crop_last_step=True,
             back_prop=True, swap_memory=False, **flags):
    self.batch_placeholder = batch_placeholder
    self.get_tracked_outputs = get_tracked_outputs

    inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos))
    max_len = max_len if max_len is not None else (2 * inp_len + 3)

    first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags)
    shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

    # Actual decoding
    def should_continue_translating(*stack):
        stack = self.Stack(*stack)
        return tf.reduce_any(tf.less(stack.out_len, max_len)) & tf.reduce_any(~stack.finished)

    def inference_step(*stack):
        stack = self.Stack(*stack)
        return self.greedy_step(model, stack, **flags)

    final_stack = tf.while_loop(
        cond=should_continue_translating,
        body=inference_step,
        loop_vars=first_stack,
        shape_invariants=shape_invariants,
        swap_memory=swap_memory,
        back_prop=back_prop,
    )

    outputs, _, scores, _, dec_states, attnP, tracked_outputs = final_stack

    if crop_last_step:
        attnP = attnP[:, :-1]
        tracked_outputs = nested_map(lambda out: out[:, :-1], tracked_outputs)

    if force_eos:
        out_mask = infer_mask(outputs, model.out_voc.eos)
        outputs = tf.where(out_mask, outputs, tf.fill(tf.shape(outputs), model.out_voc.eos))

    self.best_out = outputs
    self.best_attnP = attnP
    self.best_scores = scores
    self.dec_states = dec_states
    self.tracked_outputs = tracked_outputs
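# A hedged usage sketch for the greedy decoder above. The class name (GreedyInference) and the
# make_feed_dict helper are assumptions for illustration; the actual entry points may differ:
#
#     inference = GreedyInference(model, batch_placeholder, max_len=50)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         translations, scores = sess.run([inference.best_out, inference.best_scores],
#                                         feed_dict=make_feed_dict(batch))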
def __init__(self, model, batch_placeholder, min_len=None, max_len=None,
             beam_size=12, beam_spread=3, beam_spread_raw=None, force_bos=False,
             if_no_eos='last', back_prop=True, swap_memory=False, **flags):
    assert if_no_eos in ['last', 'initial']
    assert np.isfinite(beam_spread) or max_len != float('inf'), \
        "Must set maximum length if beam_spread is infinite"

    # initialize fields
    self.batch_placeholder = batch_placeholder
    inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos))
    self.min_len = min_len if min_len is not None else inp_len // 4 - 1
    self.max_len = max_len if max_len is not None else 2 * inp_len + 3

    self.beam_size, self.beam_spread = beam_size, beam_spread
    if beam_spread_raw is None:
        self.beam_spread_raw = beam_spread
    else:
        self.beam_spread_raw = beam_spread_raw
    self.force_bos, self.if_no_eos = force_bos, if_no_eos

    # actual beam search
    first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags)
    shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack)

    def should_continue_translating(*stack):
        stack = self.Stack(*stack)
        should_continue = self.should_continue_translating(model, stack)
        return tf.reduce_any(should_continue)

    def expand_hypos(*stack):
        return self.beam_search_step(model, self.Stack(*stack), **flags)

    last_stack = tf.while_loop(
        cond=should_continue_translating,
        body=expand_hypos,
        loop_vars=first_stack,
        shape_invariants=shape_invariants,
        back_prop=back_prop,
        swap_memory=swap_memory,
    )

    # crop unnecessary EOSes that occur if no hypothesis is updated on several last steps
    actual_length = infer_length(last_stack.best_out, model.out_voc.eos)
    max_length = tf.reduce_max(actual_length)
    last_stack = last_stack._replace(best_out=last_stack.best_out[:, :max_length])

    self.best_out = last_stack.best_out
    self.best_attnP = last_stack.best_attnP
    self.best_scores = last_stack.best_scores
    self.best_raw_scores = last_stack.best_raw_scores
    self.best_state = last_stack.best_dec_state
def switch(self, condition, state_on_true, state_on_false):
    """
    Composes a new stack.best_dec_state from the new dec state where new_is_better
    and from the old dec state otherwise.
    :param condition: a boolean condition vector of shape [batch_size]
    """
    return nested_map(lambda x, y: tf.where(condition, x, y), state_on_true, state_on_false)
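# A self-contained illustration of the behaviour switch() relies on: with a rank-1 boolean
# condition of length batch_size, TF1's tf.where picks entire per-sample rows from x or y.
def _switch_behaviour_demo():
    cond = tf.constant([True, False])
    on_true = tf.constant([[1, 1], [2, 2]])
    on_false = tf.constant([[9, 9], [8, 8]])
    return tf.where(cond, on_true, on_false)  # evaluates to [[1, 1], [8, 8]]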
def __init__(self, model, is_train=False):
    """
    An object that finds the most likely sequence of inserts
    :type model: lib.models.Transformer
    """
    self.model = model
    self.batch_ph = make_batch_placeholder(
        model.make_feed_dict(model._get_batch_sample()))

    self.k_best = tf.placeholder('int32', [])
    self.hypo_base_logprobs = tf.placeholder('float32', [None])  # [batch_size]

    # a mask of allowed tokens
    not_special = np.array([(i not in model.out_voc._default_token_ix)
                            for i in range(len(model.out_voc))],
                           dtype=np.bool)
    self.allowed_tokens = tf.placeholder_with_default(
        not_special,
        shape=[len(model.out_voc)],
    )
    # ^-- [voc_size]

    # step 1: precompute encoder outputs for all unique input lines
    enc = model.encode(self.batch_ph, is_train)
    self.cached_enc_state = {}
    for key, value in enc.items():
        self.cached_enc_state[key] = tf.Variable(tf.zeros([], value.dtype),
                                                 validate_shape=False, trainable=False,
                                                 name=value.name[:-2] + '_cached')
        self.cached_enc_state[key].set_shape(value.shape)

    self.compute_enc_state = list(nested_flatten(nested_map(
        lambda var, val: tf.assign(var, val, validate_shape=False),
        self.cached_enc_state, enc)))

    # step 2: assemble decoder outputs for each input
    # there may be several out hypos for the same inp line, up to beam_size hypos to be exact
    out_to_inp_ix = tf.zeros([tf.shape(self.batch_ph['out'])[0]], dtype=tf.int64)
    enc_reordered = {k: tf.gather(v, out_to_inp_ix)
                     for k, v in self.cached_enc_state.items()}

    # step 3: compute logits and action log-probs for inserting tokens and finishing
    logp = model.compute_action_logprobs(self.batch_ph, is_train, enc=enc_reordered)

    ###################
    # insert operation
    hypo_logprobs_insert = self.hypo_base_logprobs[:, None, None] + logp['insert']
    # ^-- [batch, position, token]
    hypo_logprobs_insert -= 1e9 * tf.to_float(tf.logical_not(self.allowed_tokens))
    best_inserts_flat = tf.nn.top_k(tf.reshape(hypo_logprobs_insert, [-1]),
                                    k=self.k_best, sorted=True)

    batch_size, max_len = tf.shape(self.batch_ph['out'])[0], tf.shape(self.batch_ph['out'])[1]
    voc_size = len(model.out_voc)

    best_hypo_ix = best_inserts_flat.indices // (max_len * voc_size)
    best_insert_pos = (best_inserts_flat.indices // voc_size) % max_len
    best_token_ix = best_inserts_flat.indices % voc_size
    best_insert_logp = best_inserts_flat.values
    self.insert_kbest = [best_hypo_ix, best_insert_pos, best_token_ix, best_insert_logp]

    ##################
    # eos operation
    self.finished_hypo_logprobs = self.hypo_base_logprobs + logp['finish']
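# A small NumPy check of the flat-index arithmetic above: one flat index into the reshaped
# [batch * position * token] scores decomposes back into (hypo, position, token).
# The concrete numbers below are illustrative only.
import numpy as np

_max_len, _voc_size = 4, 10
_flat = 2 * (_max_len * _voc_size) + 3 * _voc_size + 7   # hypo 2, position 3, token 7
assert _flat // (_max_len * _voc_size) == 2
assert (_flat // _voc_size) % _max_len == 3
assert _flat % _voc_size == 7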
def __init__(self, model, sess=None, optimized_variables=None,
             name=None, verbose=False, is_train=True, initialize=True,
             sampler_opts=None, optimizer_opts=None, grad_clip=0, **kwargs):
    """
    An imperative trainer is an object that performs training on batches. Works out-of-graph
    (in python). It is hard-coded to do one thing - sample-based training - but it does that
    thing well.
    :type model: lib.models.Transformer
    :param sess: tf session to use. tf.get_default_session by default, create new if no default.
    """
    self.model = model
    self.name = name = name or 'trainer_' + model.name
    self.sess = sess = sess or tf.get_default_session() or tf.InteractiveSession()
    self.verbose = verbose

    with tf.name_scope(self.name), tf.variable_scope(self.name) as scope:
        optimized_variables = optimized_variables or get_optimized_variables(model, verbose)
        self.optimized_variables = optimized_variables
        self.step = tf.train.get_or_create_global_step(sess.graph)

        # gradient accumulators (for virtual batch training)
        self.accumulated_grads = [
            tf.Variable(tf.zeros_like(w), trainable=False, name=w.name[:-2] + '_acc')
            for w in optimized_variables
        ]
        self.accumulated_num_batches = tf.Variable(
            tf.zeros(()), trainable=False, name='num_batches_since_update')

        ############
        # step 1: precompute encoder state for all unique input lines
        self.encoder_batch_ph = self.model.make_encoder_batch_ph()
        enc = model.encode(self.encoder_batch_ph, is_train)
        self.cached_enc_state, self.compute_enc_state = make_symbolic_cache(enc)

        ############
        # step 2: path_sampler samples a batch of trajectories (sequences of inserts)
        # it also caches encoder state for efficiency
        self.path_sampler = SampleReferenceInserts(
            model, **(sampler_opts or {}), enc_state=self.cached_enc_state)
        self.cached_enc_state = nested_map(tf.stop_gradient, self.cached_enc_state)

        self.cached_grad_wrt_enc = nested_map(
            lambda v: tf.Variable(tf.zeros([]), validate_shape=False,
                                  trainable=False, name=v.name[:-2] + '_cached_grad'),
            self.cached_enc_state)
        self.reset_cached_grad_wrt_enc = nested_map(
            lambda acc, tensor: tf.assign(acc, tf.zeros_like(tensor), validate_shape=False),
            self.cached_grad_wrt_enc, self.cached_enc_state)

        self.fetch_before_batch = tf.group([self.reset_cached_grad_wrt_enc])

        ############
        # step 3: a trajectory is split into slices (for memory efficiency),
        # for each slice we compute dL/d_w_dec and dL/d_enc_state
        self.slice_ph = {
            'out': tf.placeholder('int32', [None, None]),
            'out_len': tf.placeholder('int32', [None]),
            'out_to_inp_indices': tf.placeholder('int32', [None]),
            'ref_len': tf.placeholder('int32', [None]),
            'ref_inserts': tf.placeholder('int64', [None, 3]),
            'chosen_inserts': tf.placeholder('int64', [None, 3]),
        }

        loss_on_slice, counters_on_slice = self.get_loss_and_counters(
            self.slice_ph, self.cached_enc_state, is_train=is_train, **kwargs)

        flat_enc_keys = sorted(self.cached_enc_state.keys())
        flat_enc_cache = list(self.cached_enc_state[k] for k in flat_enc_keys)
        flat_accumulated_grad_wrt_enc = [self.cached_grad_wrt_enc[k] for k in flat_enc_keys]

        loss_grads_on_slice = tf.gradients(loss_on_slice, optimized_variables + flat_enc_cache)
        weight_and_enc_grad_accumulators = self.accumulated_grads + flat_accumulated_grad_wrt_enc
        self.update_grads_on_slice = [
            tf.assign_add(grad_acc, grad)
            for grad_acc, grad in zip(weight_and_enc_grad_accumulators, loss_grads_on_slice)
            if grad is not None
        ]
        # ^-- sess.run-ning this will update gradients w.r.t. decoder weights and encoder state

        # accumulators for metrics
        self.accumulated_counters = nested_map(
            lambda v: tf.Variable(tf.zeros(v.shape, v.dtype), trainable=False),
            counters_on_slice)
        self.update_counters_on_slice = nested_map(
            tf.assign_add, self.accumulated_counters, counters_on_slice)
        self.fetch_on_slice = tf.group(
            [self.update_grads_on_slice, self.update_counters_on_slice])

        ############
        # step 4: once we're finished with all slices in one batch, it's time we compute the
        # remaining gradients: dL/d_w_enc = dL/d_enc_state * d_enc_state/d_w_enc
        encoder_state = model.encode(self.encoder_batch_ph, is_train=is_train)
        flat_encoder_state = [encoder_state[k] for k in flat_enc_keys]
        loss_grads_after_slice = tf.gradients(
            flat_encoder_state, optimized_variables, grad_ys=flat_accumulated_grad_wrt_enc)
        self.update_grads_after_batch = [
            tf.assign_add(grad_acc, grad)
            for grad_acc, grad in zip(self.accumulated_grads, loss_grads_after_slice)
            if grad is not None
        ]

        self.fetch_after_batch = tf.group([
            self.update_grads_after_batch,
            tf.assign_add(self.accumulated_num_batches, 1)
        ])

        ############
        # step 5: after one or several batches, we use the accumulated gradients to perform
        # an optimization step, compute metrics for the summary and then reset all buffers
        with tf.control_dependencies([
            tf.assert_positive(self.accumulated_num_batches,
                               message='Accumulate gradients over at least one '
                                       'full batch before averaging them')
        ]):
            loss_denominator = self.get_denominator(self.accumulated_counters)
            self.grads_avg = [grad_acc / loss_denominator
                              for grad_acc in self.accumulated_grads]

        self.opt = self.get_optimizer(self.step, **(optimizer_opts or {}))

        if grad_clip:
            grads, self.grads_global_norm = tf.clip_by_global_norm(self.grads_avg, grad_clip)
        else:
            grads, self.grads_global_norm = self.grads_avg, tf.global_norm(self.grads_avg)

        self.apply_gradients = tf.group(
            self.opt.apply_gradients(zip(grads, optimized_variables), global_step=self.step))

        self.reset_gradients = tf.group(
            tf.variables_initializer(self.accumulated_grads + [self.accumulated_num_batches]))

        self.compute_metrics = self.aggregate_metrics_from_counters(self.accumulated_counters)
        self.reset_counters = tf.variables_initializer(
            list(nested_flatten(self.accumulated_counters)))

    if initialize:
        sess.run([self.reset_gradients, self.reset_counters, tf.assign(self.step, 1)])
        remaining_utility_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope.name)
        initialize_uninitialized_variables(sess=sess, var_list=remaining_utility_variables)
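# A hedged sketch of the per-batch driver loop implied by the step comments above. The feed
# names (encoder_feed, slices_of_sampled_trajectories) are hypothetical; the exact way sampled
# trajectories are cut into slice feeds lives outside this constructor:
#
#     sess.run(trainer.compute_enc_state, feed_dict=encoder_feed)   # step 1: cache encoder state
#     sess.run(trainer.fetch_before_batch)                          # reset dL/d_enc accumulators
#     for slice_feed in slices_of_sampled_trajectories:             # steps 2-3: per-slice grads
#         sess.run(trainer.fetch_on_slice, feed_dict=slice_feed)
#     sess.run(trainer.fetch_after_batch, feed_dict=encoder_feed)   # step 4: backprop into encoder
#     sess.run(trainer.apply_gradients)                             # step 5: optimizer step
#     sess.run([trainer.reset_gradients, trainer.reset_counters])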