def _inverse_pth_root_graph(self, epsilon):
  graph = tf.Graph()
  with graph.as_default():
    exponent_t = tf.reshape(
        tf.placeholder(dtype=tf.float32, name="exponent", shape=None), [])
    # Apply exponent multiplier.
    exponent_t = exponent_t * self._exponent_multiplier
    input_t = tf.placeholder(dtype=tf.float32, name="input", shape=None)
    # For p = 2, 4 or 8, we use the iterative Newton-Schur method for
    # computing the inverse-pth root.
    either_p_2_4_8 = tf.logical_or(
        tf.logical_or(
            tf.equal(-1.0 / exponent_t, 2), tf.equal(-1.0 / exponent_t, 4)),
        tf.equal(-1.0 / exponent_t, 8))
    # 4096 is the largest dimension for which SVD is still tractable.
    greater_than_4096 = tf.greater(tf.shape(input_t)[0], 4096)
    run_specialized_iterative_method = tf.logical_and(greater_than_4096,
                                                      either_p_2_4_8)
    specialized_fn = functools.partial(self._specialized_inverse_pth_root,
                                       input_t, exponent_t, epsilon)
    generalized_fn = functools.partial(self._generalized_inverse_pth_root,
                                       input_t, exponent_t, epsilon)
    output, diff = tf.cond(run_specialized_iterative_method, specialized_fn,
                           generalized_fn)
    tf.identity(output, "output")
    tf.identity(tf.cast(diff, tf.float32), "diff")
  return graph.as_graph_def().SerializeToString()
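# Sketch (assumption, not part of the original file): the dispatch above picks
# the specialized iterative method only when the matrix dimension exceeds 4096
# AND the requested power is an inverse 2nd/4th/8th root. The same predicate in
# plain Python, for a hypothetical (dim, exponent) pair:
def _would_run_specialized(dim, exponent):
  p = -1.0 / exponent  # exponent is -1/p, e.g. -0.5 for p = 2.
  either_p_2_4_8 = p in (2.0, 4.0, 8.0)
  return dim > 4096 and either_p_2_4_8

assert _would_run_specialized(8192, -0.5)          # large matrix, p = 2
assert not _would_run_specialized(1024, -0.5)      # small enough for SVD
assert not _would_run_specialized(8192, -1.0 / 3)  # p = 3 has no specialized path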
def infinite_repeat(body_fn, infeed_queue):
  """Builds an infinite loop.

  Args:
    body_fn: a Python function that builds the loop body.
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to the condition.

  Returns:
    The final values of the loop-carried tensors.
  """

  def to_list(x):
    if isinstance(x, (list, tuple)):
      return list(x)
    else:
      return [x]

  def body_fn_wrapper(i, *args):
    return [i + 1] + to_list(body_fn(*args))

  outputs = training_loop.while_loop(
      # Infinite loop. Using only tf.constant(True) causes the XLA graph to
      # appear to be stateful.
      lambda i, *args: tf.logical_or(tf.constant(True), i < 10000),
      body_fn_wrapper,
      inputs=[0],
      infeed_queue=infeed_queue)
  outputs = to_list(outputs)
  if len(outputs) == 1:
    # Returns the Op rather than an empty list.
    return outputs[0].op
  else:
    return outputs[1:]
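# Usage sketch (assumption, not from the original file): a hypothetical body_fn
# that consumes one (features, labels) tuple per step from the infeed queue and
# returns the training op to run. The variable, loss, and queue names are made
# up for illustration; only infinite_repeat itself comes from the code above.
def _example_body_fn(features, labels):
  w = tf.get_variable('w', shape=[], initializer=tf.ones_initializer())
  loss = tf.reduce_mean(tf.square(w * features - labels))  # stand-in loss
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  # Returning an Op (not a tensor) means only the step counter is loop-carried.
  return optimizer.minimize(loss)

# loop_op = infinite_repeat(_example_body_fn, example_infeed_queue)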
def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask, dec_state,
                 t):
  del tgt_pos, tgt_segment_id
  [buf] = dec_state
  if tgt_id.shape == (self.batch_size, self.beam_size):
    buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
  else:
    div = int(tgt_id.shape[1] // self.beam_size)
    for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
      buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

  buf1 = tf.transpose(buf, [1, 0, 2])
  buf1 = tf.reshape(buf1, [self.batch_size, self.max_steps * self.beam_size])

  # Select next_tgt_id as a function of the previous target tokens.
  if self.rule == '+1':
    next_tgt_id = (tgt_id + 1)
    next_tgt_id %= self.vocab_size
  elif self.rule == 'sum':
    # Sum over all previous tokens in tgt_mask.
    next_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(tgt_mask, tf.int32))
    next_tgt_id %= self.vocab_size
  elif self.rule == 'fib':
    # Select the last token according to tgt_mask.
    m = tgt_mask
    m *= tf.cast(
        tf.equal(tf.cumsum(m, -1),
                 tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
    last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
    next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

  # With a lower probability, also allow next_tgt_id + 1 .. next_tgt_id + 4.
  n = self.vocab_size
  logits = 5 * tf.one_hot(next_tgt_id % n, n)
  logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
  logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
  logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
  logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

  # Increase the eos score if the current tgt_id contains a 9.
  eos_id = 0
  tgt_id_contains_9 = tf.logical_or(
      tf.equal(tgt_id % 10, 9), tf.equal((tgt_id // 10) % 10, 9))
  logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(eos_id, self.vocab_size),
                          tf.cast(tgt_id_contains_9, tf.float32))

  # Tie-breaking -- a lower token id wins by a little bit.
  tie = np.arange(0., 1., 1. / n)
  tie /= tie.sum()
  logits -= tie

  logits = tf.nn.log_softmax(logits)
  dec_state = [buf]
  return logits, dec_state
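# Sketch (assumption, for illustration only): the three rules above generate
# deterministic token sequences, which makes decoded outputs easy to check.
# With a hypothetical vocab_size of 100:
#   '+1' : next = (prev + 1) % 100, e.g. 3, 4, 5, 6, ...
#   'sum': next = (sum of all previous unmasked tokens) % 100
#   'fib': next = (prev + prev_prev) % 100, e.g. 3, 5, 8, 13, ...
# (the eos boost fires once a token contains the digit 9).
def _fib_rule_reference(start_a, start_b, vocab_size, steps):
  seq = [start_a, start_b]
  for _ in range(steps):
    seq.append((seq[-1] + seq[-2]) % vocab_size)
  return seq

assert _fib_rule_reference(3, 5, 100, 3) == [3, 5, 8, 13, 21]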
def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
  """Extends greedy search hyps for one step.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
    hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
      counted.
    done_hyps: Whether or not a hyp has finished.
    other_states: A `.NestedMap` of other beam search states. This `.NestedMap`
      is managed and updated by the client. It is expected that each of its
      member tensors are of rank >= 1. t[i, ...] is the state of the i-th hyp
      at the beginning of this search step.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      See class header comments for more details.

  Returns:
    A tuple of the following elements for the next greedy search step:
    (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states).
  """
  p = self.params
  # Increment hyp_lens by 1 if the hyp is not finished yet.
  hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

  bs_results, new_other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam=1)
  new_step_ids = tf.arg_max(bs_results.log_probs, 1)
  new_step_ids = tf.cast(new_step_ids, tf.int32)
  new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
  final_other_states = post_beam_search_step_callback(theta, encoder_outputs,
                                                      new_step_ids,
                                                      new_other_states)

  # Stash new_step_ids into the right slot.
  new_step_ids_1d = tf.reshape(new_step_ids, [-1])
  hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                             new_step_ids_1d)
  # Update done_hyps if the new step ids are the end-of-sequence token.
  done_hyps = tf.logical_or(done_hyps,
                            tf.equal(new_step_ids_1d, p.target_eos_id))

  return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
          final_other_states)
def ApplyBias():
  """Bias and update log_probs and consistent."""

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(tensor,
                     [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from the previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.logical_and(states.consistent, local_consistence)

  # Get the label and weight slices corresponding to the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, p.dtype)

  # Convert from a dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(
      1e-10, p.dtype)  # Avoid 0 probs which may cause issues with log.
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  return tf.log(probs), consistent
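# Shape sketch (assumption, illustrative only): TileForBeamAndFlatten maps a
# per-source-example vector of length src_batch to the per-hypothesis layout
# [num_hyps_per_beam * src_batch] used by the beam search, so hyp h of source
# example b lands at index h * src_batch + b. The same reshuffle in numpy:
import numpy as np

def tile_for_beam_and_flatten_np(x, num_hyps_per_beam):
  x = np.reshape(x, [1, -1])              # [1, src_batch]
  x = np.tile(x, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
  return np.reshape(x, [-1])              # [num_hyps_per_beam * src_batch]

assert list(tile_for_beam_and_flatten_np([7, 8], 3)) == [7, 8, 7, 8, 7, 8]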
def maybe_update_masks():
  with tf.name_scope(self._spec.name):
    is_step_within_pruning_range = tf.logical_and(
        tf.greater_equal(self._global_step, self._spec.begin_pruning_step),
        # If end_pruning_step is negative, keep pruning forever!
        tf.logical_or(
            tf.less_equal(self._global_step, self._spec.end_pruning_step),
            tf.less(self._spec.end_pruning_step, 0)))
    is_pruning_step = tf.less_equal(
        tf.add(self._last_update_step, self._spec.pruning_frequency),
        self._global_step)
    return tf.logical_and(is_step_within_pruning_range, is_pruning_step)
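# Sketch (assumption, not part of the original class): the same schedule
# predicate in plain Python, with hypothetical numbers. Masks are refreshed
# only inside [begin_pruning_step, end_pruning_step] (or forever if the end
# step is negative), and only once every pruning_frequency steps.
def _should_update_masks(global_step, last_update_step, begin_pruning_step,
                         end_pruning_step, pruning_frequency):
  within_range = global_step >= begin_pruning_step and (
      global_step <= end_pruning_step or end_pruning_step < 0)
  is_pruning_step = last_update_step + pruning_frequency <= global_step
  return within_range and is_pruning_step

assert _should_update_masks(1000, 900, 500, -1, 100)      # due for an update
assert not _should_update_masks(1000, 950, 500, -1, 100)  # updated too recently
assert not _should_update_masks(400, 0, 500, 2000, 100)   # before begin step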
def ScaleGradients(self, var_grads, gradient_adjuster=None):
  """Scales gradients according to training params.

  Args:
    var_grads: a `.NestedMap` whose values are (var, grad) pairs.
    gradient_adjuster: if not None, a function that mutates a given var_grads.

  Returns:
    A `.NestedMap` containing:

    - has_nan_or_inf: a scalar of 0 or 1, indicating whether there is any NaN
      or Inf in the input gradients.
    - final_var_grads: a `.NestedMap` whose values are (var, grad) pairs, where
      gradients have already been scaled.
    - grad_scale: the gradient scale. 0 if gradient updates should be skipped
      for the step. (Optional, only returned when global norm clipping is
      used.)
  """
  p = self.params

  # Computes gradients' norm and adds their summaries. Note that all_grad_norm
  # may be nan, which may cause grad_scale to be nan.
  for name, vg in var_grads.FlattenItems():
    summary_utils.AddNormSummary(name + '/' + p.name, py_utils.NestedMap(s=vg))
  all_grad_norm = tf.sqrt(
      py_utils.SumSquared(
          [g for (_, g) in py_utils.NestedMap(child=var_grads).Flatten()]))
  all_var_norm = tf.sqrt(
      py_utils.SumSquared(
          [v for (v, _) in py_utils.NestedMap(child=var_grads).Flatten()]))
  grad_norm_is_nan_or_inf = tf.logical_or(
      tf.is_nan(all_grad_norm), tf.is_inf(all_grad_norm))

  # Optional gradient adjustment. Note that this happens after computing
  # all_grad_norm.
  if gradient_adjuster is not None:
    tf.logging.info('gradient_adjuster=%s', gradient_adjuster)
    var_grads = gradient_adjuster(var_grads)

  # Handles NaN/Inf gradients.
  has_nan_or_inf = py_utils.HasNanOrInfGradient(var_grads)
  # Grad norm can still be inf even if none of the individual grads is inf.
  has_nan_or_inf = tf.logical_or(has_nan_or_inf, grad_norm_is_nan_or_inf)

  return_values = py_utils.NestedMap()
  if p.clip_gradient_single_norm_to_value:
    # Currently using both types of clipping simultaneously is unsupported.
    if p.clip_gradient_norm_to_value:
      raise ValueError(
          'Cannot use clip_gradient_single_norm_to_value=%f and '
          'clip_gradient_norm_to_value=%f.' %
          (p.clip_gradient_single_norm_to_value,
           p.clip_gradient_norm_to_value))
    final_var_grads = py_utils.ApplyGradNormCliping(
        var_grads, p.clip_gradient_single_norm_to_value)
  else:
    grad_scale = self._GetGlobalGradScale(all_grad_norm, has_nan_or_inf)
    self._AddEvalMetric('grad_norm/all', all_grad_norm, tf.constant(1.0))
    self._AddEvalMetric('var_norm/all', all_var_norm, tf.constant(1.0))
    self._AddEvalMetric('grad_scale_all', grad_scale, tf.constant(1.0))
    final_var_grads = py_utils.ApplyGradMultiplier(var_grads, grad_scale)
    return_values.grad_scale = grad_scale

  return_values.has_nan_or_inf = has_nan_or_inf
  return_values.final_var_grads = final_var_grads
  return return_values
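# Sketch (assumption: this is NOT the body of _GetGlobalGradScale, only a
# plausible plain-Python reference for how a global-norm gradient scale is
# commonly computed -- scale down so the global norm never exceeds the clip
# value, and drop the step entirely on NaN/Inf):
def _reference_global_grad_scale(all_grad_norm, has_nan_or_inf,
                                 clip_gradient_norm_to_value):
  if has_nan_or_inf:
    return 0.0  # Skip the update for this step.
  if clip_gradient_norm_to_value:
    return min(1.0, clip_gradient_norm_to_value / all_grad_norm)
  return 1.0

assert _reference_global_grad_scale(10.0, False, 5.0) == 0.5
assert _reference_global_grad_scale(2.0, False, 5.0) == 1.0
assert _reference_global_grad_scale(2.0, True, 5.0) == 0.0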
def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                              num_hyps_per_beam, *args, **kwargs):
  """Wrapper that adds a bias to _PreBeamSearchStepCallback.

  Biases results.log_probs towards the provided encoder_outputs.targets.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by the encoder.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps.
    num_hyps_per_beam: Beam size.
    *args: additional arguments to _PreBeamSearchStepCallback.
    **kwargs: additional arguments to _PreBeamSearchStepCallback.

  Returns:
    A tuple (results, out_states).

    results: A `.NestedMap` of beam search results.
      atten_probs: The updated attention probs, of shape [tgt_batch, src_len].
      log_probs: Log prob for each of the tokens in the target vocab. This is
        of shape [tgt_batch, vocab_size].
    out_states: A `.NestedMap` of the updated states. The states relevant here
      are:
      time_step: A scalar indicating the current step of the decoder. Must be
        provided and maintained by the subclass.
      consistent: A boolean vector of shape [tgt_batch, ] which tracks whether
        each hypothesis has exactly matched encoder_outputs.targets so far.
  """
  p = self.params
  time_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)
  labels = encoder_outputs.targets.labels
  weights = encoder_outputs.targets.weights

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(tensor,
                     [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from the previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  out_states.consistent = tf.logical_and(states.consistent, local_consistence)

  # Get the label and weight slices corresponding to the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(out_states.consistent, p.dtype)

  # Convert from a dense label to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(
      1e-10, p.dtype)  # Avoid 0 probs which may cause issues with log.
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  bs_results.log_probs = tf.log(probs)

  return bs_results, out_states
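# Worked sketch (assumption, hypothetical numbers): with per-position weight w,
# the biased distribution is (1 - w) * pred_probs + w * label_probs, so w = 0
# leaves the model prediction untouched and w = 1 (on a consistent prefix)
# essentially forces the ground-truth label.
import numpy as np

pred_probs = np.array([0.7, 0.2, 0.1])   # model distribution over a 3-token vocab
label_probs = np.array([0.0, 1.0, 0.0])  # ignoring the tiny `uncertainty` smoothing
for w in (0.0, 0.5, 1.0):
  biased = (1.0 - w) * pred_probs + w * label_probs
  print(w, biased)  # 0.0 -> model probs, 1.0 -> label probs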