def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length-aware (based on the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
    - Concatenation of `x` and `y` of shape
      [batch_size, x_len_max + y_len_max].
    - Paddings of the concatenation of shape
      [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Pad `x` with necessary <pad>.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Replace all <pad> with 0.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

  # Compute the write indices of `y` in `xy`.
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # We need to remap all <pad> to `pad`.
  xy = tf.where(
      tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
              tf.expand_dims(x_len + y_len, 1)), xy,
      tf.fill(py_utils.GetShape(xy), pad))
  xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                     py_utils.GetShape(xy)[1],
                                     x_paddings.dtype)
  return xy, xy_paddings
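# Minimal usage sketch of SequenceConcat. The values below are illustrative
# toy inputs, not taken from the library's tests.
def _SequenceConcatExampleSketch():
  x = tf.constant([[1, 2, 0]], dtype=tf.int32)
  x_paddings = tf.constant([[0., 0., 1.]])  # `x` has length 2.
  y = tf.constant([[3, 0]], dtype=tf.int32)
  y_paddings = tf.constant([[0., 1.]])      # `y` has length 1.
  xy, xy_paddings = SequenceConcat(x, x_paddings, y, y_paddings, pad=0)
  # xy          -> [[1, 2, 3, 0, 0]]
  # xy_paddings -> [[0., 0., 0., 1., 1.]]
  return xy, xy_paddings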
def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general cases."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Convert to [B, N, T, T]
  # part 1
  term_bd_left = term_bd[:, :, :, :t]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  term_bd_left = RelShift(term_bd_left)  # [B, N, T, T]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  # part 2
  term_bd_right = term_bd[:, :, :, t - 1:]  # [B, N, T, T]
  term_bd_right = RelShift(term_bd_right)
  # [lower triangle]
  mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

  # stitching together
  return tf.where(mask > 0, term_bd_left, term_bd_right)
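# `RelShift` is used above but defined elsewhere in the module. A common
# implementation of the pad-reshape-slice trick it names is sketched below
# for a [B, N, T, T] input; treat this as illustrative, not the library's
# exact definition. Semantically, out[..., i, j] == in[..., i, j - i] for
# j >= i; entries below the diagonal are junk and are expected to be masked
# by the caller (as the band_part mask above does).
def _RelShiftSketch(x):
  b, n, t, _ = py_utils.GetShape(x)
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))  # Pad one column of zeros.
  x = tf.reshape(x, [b, n, t + 1, t])              # Fold pad into the rows.
  return x[:, :, :t, :]                            # Drop the extra row.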
def Polynomial(x):
  """Polynomial function of x."""
  p = self.params
  x0, y0 = p.start
  x1, y1 = p.limit
  assert x0 < x1, '%s must be < %s' % (x0, x1)
  x0 = tf.cast(x0, dtype=x.dtype)
  x1 = tf.cast(x1, dtype=x.dtype)
  y0 = tf.cast(y0, dtype=x.dtype)
  y1 = tf.cast(y1, dtype=x.dtype)
  f_x = ((x - x0) / (x1 - x0))**p.power
  y = y0 + f_x * (y1 - y0)
  return tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
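# Standalone sketch of the ramp above, with assumed endpoints start=(0, 0.0),
# limit=(100, 1.0) and power=1: a linear ramp from 0 to 1, clamped outside
# [x0, x1] by the nested tf.where.
def _PolynomialExampleSketch():
  x = tf.constant([-5., 50., 200.])
  x0, y0, x1, y1, power = 0., 0., 100., 1., 1.
  f_x = ((x - x0) / (x1 - x0))**power
  y = y0 + f_x * (y1 - y0)
  return tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
  # -> [0.0, 0.5, 1.0]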
def CreateTpuEmbeddingEnqueueOps(self): """Creates the TpuEmbedding enqueue ops on the host. Note that this must be called after the instantiation of the monolithic TPUEmbeddingLayer. """ p = self.params cluster = self.cluster num_tpu_hosts = cluster.num_tpu_hosts num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1 tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING) tpu_embedding = (tpu_embedding_collection[0] if tpu_embedding_collection else None) enqueue_ops = [] if num_tpu_hosts > 1 and tpu_embedding is not None: if not p.use_per_host_infeed: tf.logging.fatal( 'TPU Embedding must be used with per_host_infeed with multiple ' 'TPU host topologies.') tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys()) if tpu_embedding is not None else []) tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys) if not tpu_embedding: return for task_id in range(num_infeed_hosts): host_device = '/task:{}/device:CPU:0'.format(task_id) with tf.device(host_device): if isinstance(self._batch, py_utils.NestedMap): # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU. # Note that when MultiTaskData is used, bucket_keys will be at the # second level of the dictionary. self._batch = self._batch.FilterKeyVal( lambda k, _: not k.endswith('bucket_keys')) tf.logging.info('host_device: %s, batch: %r', host_device, self._batch) enqueue_dict_per_core = [ {} for _ in range(tpu_embedding.num_cores_per_host) ] num_cores_per_host = tpu_embedding.num_cores_per_host for key in tpu_emb_input_keys: feat = self._batch[key] tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host) for core, split in enumerate(tpu_emb_feat_splitted): # Dense to sparse. Note the assumption of a padding id. sample_indices = tf.where(tf.not_equal(split, -1)) embedding_indices = tf.gather_nd(split, sample_indices) enqueue_data = tpu_embedding_lib.EnqueueData( embedding_indices, sample_indices) enqueue_dict_per_core[core][key] = enqueue_data enqueue_ops += tpu_embedding.generate_enqueue_ops( enqueue_dict_per_core) self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def FProp(self, theta, current_step):
  """Returns the current learning rate decay."""
  p = self.params
  current_step = tf.cast(current_step, tf.float32)
  warmup_steps = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
  if p.decay_end is not None:
    current_step = tf.where(current_step < p.decay_end, current_step,
                            tf.cast(p.decay_end, tf.float32))
  return p.model_dim**-0.5 * tf.minimum(
      (current_step + 1) * warmup_steps**-1.5, (current_step + 1)**-0.5)
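# Closed form of the schedule above (the standard Transformer warmup+decay):
#   lr(s) = model_dim**-0.5 * min((s + 1) * warmup**-1.5, (s + 1)**-0.5)
# A NumPy sketch with assumed model_dim=512 and warmup_steps=4000: the rate
# climbs linearly to roughly 7.0e-4 near step 4000, then decays as s**-0.5.
def _TransformerScheduleSketch(step, model_dim=512, warmup_steps=4000):
  import numpy as np
  return model_dim**-0.5 * np.minimum((step + 1) * warmup_steps**-1.5,
                                      (step + 1)**-0.5)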
def FProp(self, theta, current_step): p = self.params with tf.name_scope(p.name): steps = self._best_step best_step = steps[0] last_step = steps[1] ref_step = tf.maximum(self._ref_step, best_step) f = self._cur_factor # Decay if no improvement within window. new_factor = tf.where(last_step - ref_step < p.window, f, tf.maximum(p.min_factor, f * p.decay)) # Update ref_step if we decayed. new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step) update_step = tf.assign(self._ref_step, new_step) with tf.control_dependencies([update_step]): return tf.assign(self._cur_factor, new_factor)
def PostTrainingStepUpdate(self, global_step):
  """Updates moving_mean, moving_variance after each training step."""
  p = self.params
  # Get sufficient stats that accumulate over microbatches.
  counts = self.accumulators.counts.GetValue()
  mean_ss = self.accumulators.mean_ss.GetValue()
  variance_ss = self.accumulators.variance_ss.GetValue()
  # Compute batch mean and batch variance from sufficient stats.
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
  # Update moving_mean, moving_variance from batch mean and batch variance.
  with tf.name_scope(p.name) as scope:
    with tf.colocate_with(self.vars.moving_mean):
      mean_update = tf.assign_sub(
          self.vars.moving_mean,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
              tf.zeros_like(self.vars.moving_mean)),
          name='moving_mean_update')
    with tf.colocate_with(self.vars.moving_variance):
      var_update = tf.assign_sub(
          self.vars.moving_variance,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
              decay, tf.zeros_like(self.vars.moving_variance)),
          name='moving_variance_update')
    py_utils.CheckNumerics(
        self.vars.moving_mean,
        'moving mean of {} failed numeric check'.format(scope))
    py_utils.CheckNumerics(
        self.vars.moving_variance,
        'moving variance of {} failed numeric check'.format(scope))
  self.accumulators.counts.Reset()
  self.accumulators.mean_ss.Reset()
  self.accumulators.variance_ss.Reset()
  return tf.group(mean_update, var_update)
def _MaybeFakeQuant(self, inputs, min_v, max_v, num_bits):
  p = self.params

  def Apply():
    return tf.quantization.fake_quant_with_min_max_vars(
        inputs, min_v, max_v, num_bits=num_bits)

  if p.delay_start_steps != 0 and not self.do_eval:
    if p.delay_start_steps == -1:
      return inputs
    return tf.where(self.theta.global_step >= p.delay_start_steps, Apply(),
                    inputs)
  else:
    return Apply()
def FProp(self, theta, current_step): """Returns the current learning rate decay.""" params = self.params warmup_steps = tf.cast(params.decay_start * params.worker_replicas, tf.float32) current_step = tf.cast(current_step, tf.float32) if params.decay_end is not None: current_step = tf.where(current_step < params.decay_end, current_step, tf.cast(params.decay_end, tf.float32)) peak_learning_rate = (warmup_steps**-0.5) return (params.model_dim**-0.5) * tf.minimum( tf.minimum((current_step + 1), (current_step + 1)**-0.5), peak_learning_rate)
def QuantizeTensors(self, t_name, ts, eval_only=False): p = self.params # Always straddle a real zero point. if self.do_eval: # At eval/inference time, use the memorized range. # Important: Don't capture these variables in training mode so as to # avoid extra/unnecessary captures. min_var = self._GetQStateVar(t_name, 'min') max_var = self._GetQStateVar(t_name, 'max') return [ self._MaybeFakeQuant(t, min_var, max_var, num_bits=p.bits) for t in ts ] else: # At training time, use the batch calculated min/max. accumulator_name = self._GetAccumulatorNameForTensor(t_name) # Calculate min/max for all tensors. batch_min = 0.0 batch_max = 0.0 for t in ts: batch_min = tf.minimum(tf.reduce_min(t), batch_min) batch_max = tf.maximum(tf.reduce_max(t), batch_max) # New state. state1 = tf.stack([1.0, batch_min, batch_max]) self.accumulators[accumulator_name].Update(state1) # Results. ts_out = [] for i, t in enumerate(ts): if eval_only: # If only quantizing at eval time, still record ranges as above # but don't quantize. quant_t = t else: # If quantizing during training, skip quantization if it produces # NANs. Sometimes early in the training process, things are unstable # and ranges can produce numerical instability that makes it # impossible to perform a fake_quant. quant_t = self._MaybeFakeQuant(t, batch_min, batch_max, num_bits=p.bits) # TODO(laurenzo): Plumb quant_t_has_nans through state and report. quant_t_has_nans = tf.math.is_nan(quant_t) quant_t = tf.where(quant_t_has_nans, t, quant_t) ts_out.append(quant_t) summary_utils.histogram( '%s/%s_%d' % (self._qvars_scope.name, t_name, i), t) return ts_out
def _RecordTensor(self, t_name): p = self.params if self.do_eval: return [] accumulator_name = self._GetAccumulatorNameForTensor(t_name) accumulator = self.accumulators[accumulator_name] min_var = self._GetQStateVar(t_name, 'min') max_var = self._GetQStateVar(t_name, 'max') # Unpack state tensor. current_value = accumulator.GetValue() count = current_value[0] min_value = current_value[1] max_value = current_value[2] accumulator.Reset() def Ema(variable, value): return (1.0 - p.ema_decay) * (variable - value) # Note that small floating point issues can cause ranges that naturally # begin or end at zero to move slightly past, causing hard failures # downstream (checks that all ranges straddle zero). We therefore repeat # the straddling constraint here. return [ tf.assign( min_var, tf.minimum( 0., min_var - tf.where(count > 0., Ema(min_var, min_value), 0.))), tf.assign( max_var, tf.maximum( 0., max_var - tf.where(count > 0., Ema(max_var, max_value), 0.))), ]
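# Note on the update above: min_var - (1 - ema_decay) * (min_var - min_value)
# simplifies to ema_decay * min_var + (1 - ema_decay) * min_value, i.e. a
# standard exponential moving average toward the batch min/max, with the
# outer tf.minimum/tf.maximum re-pinning the range to straddle zero.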
def QuantizeWeight(self, w): p = self.params w_min = tf.reduce_min(w) w_max = tf.reduce_max(w) # NOTE: We force a small, non-zero range because otherwise, zero weights # can cause downstream inference engines to blow up. w_min = tf.minimum(w_min, -p.quantize_weight_epsilon) w_max = tf.maximum(w_max, p.quantize_weight_epsilon) quant_w = self._MaybeFakeQuant(w, w_min, w_max, num_bits=p.bits) if self.do_eval: return quant_w else: # If quantizing during training, skip quantization if it produces # NANs. Sometimes early in the training process, things are unstable # and ranges can produce numerical instability that makes it # impossible to perform a fake_quant. quant_w_has_nans = tf.math.is_nan(quant_w) return tf.where(quant_w_has_nans, w, quant_w)
def GetState(self, theta): """Gets the state from theta.""" p = self.params if p.is_inference: # State is not used for inference. Just return dummy. return tf.zeros([1], tf.float32) else: # Calculations/vars need to be float but these can be ints in the params. clip_end_step = tf.cast(p.clip_end_step, tf.float32) clip_start_step = tf.cast(p.clip_start_step, tf.float32) quant_start_step = tf.cast(p.quant_start_step, tf.float32) global_step = tf.cast(theta.global_step, tf.float32) # Will be negative if before clipping starts. clip_ratio = (tf.minimum(clip_end_step - clip_start_step, global_step - clip_start_step) / tf.maximum(1.0, clip_end_step - clip_start_step)) # Currently fq is either on (1.0) or off (-1.0). Progressive quantization # may later occupy 0..1.0. fq_ratio = tf.where(global_step < quant_start_step, -1.0, 1.0) return tf.stack([clip_ratio, fq_ratio])
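# Timeline sketch of the state above, with assumed step parameters
# clip_start_step=1000, clip_end_step=2000, quant_start_step=3000:
#   step <  1000 : clip_ratio < 0  (no clipping yet), fq_ratio = -1
#   1000 .. 2000 : clip_ratio ramps 0 -> 1 (caps interpolate start -> end)
#   step >= 3000 : fq_ratio = 1    (fake-quant switches on)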
def _GetGlobalGradScale(self, all_grad_norm, has_nan_or_inf):
  """Returns a scaling factor for all gradients according to their norm.

  In case there are NaN or Inf values the function will return 0.0.

  Args:
    all_grad_norm: A scalar representing the total norm of all vars.
    has_nan_or_inf: A scalar of 0 or 1, indicating whether there is any NaN
      or Inf in input gradients.

  Returns:
    The gradient scale. 0 if gradient updates should be skipped for the step.
  """
  p = self.params
  # Computes gradient's scale.
  grad_scale = tf.constant(1.0)
  if p.clip_gradient_norm_to_value:
    # If all_grad_norm > p.clip_gradient_norm_to_value, scales all_grads so
    # that the norm equals p.clip_gradient_norm_to_value.
    grad_scale = tf.minimum(1.0,
                            p.clip_gradient_norm_to_value / all_grad_norm)
  if p.grad_norm_to_clip_to_zero:
    # If all_grad_norm > p.grad_norm_to_clip_to_zero, treats grad_scale as 0.
    # This way, we ignore this step.
    grad_scale *= tf.cast(all_grad_norm < p.grad_norm_to_clip_to_zero,
                          p.dtype)
  if p.grad_norm_tracker:
    grad_scale *= self.grad_norm_tracker.FPropDefaultTheta(
        all_grad_norm, has_nan_or_inf)
  # Force grad_scale to be 0 if there is any NaN or Inf in gradients.
  grad_scale = tf.where(has_nan_or_inf, 0.0, grad_scale)
  return grad_scale
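# Standalone sketch of the clipping arithmetic above, with assumed values
# clip_gradient_norm_to_value=10.0 and grad_norm_to_clip_to_zero=100.0.
def _GradScaleSketch(all_grad_norm, has_nan_or_inf,
                     clip_norm=10.0, zero_norm=100.0):
  grad_scale = tf.minimum(1.0, clip_norm / all_grad_norm)  # Rescale to 10.
  grad_scale *= tf.cast(all_grad_norm < zero_norm, tf.float32)  # Skip step.
  return tf.where(has_nan_or_inf, 0.0, grad_scale)
# _GradScaleSketch(5.0, False)   -> 1.0   (norm already under the cap)
# _GradScaleSketch(40.0, False)  -> 0.25  (scaled back down to norm 10)
# _GradScaleSketch(500.0, False) -> 0.0   (step skipped entirely)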
def _StringToToken(self, tokstr):
  return tf.where(
      ops.token_in_vocab(tokstr, vocab=self._pieces),
      ops.vocab_token_to_id(tokstr, vocab=self._pieces),
      tf.broadcast_to(NO_TOKEN, tf.shape(tokstr)))
def _EncodeToIds(self, word): # Below: # * a token is a wordpiece ID. # * the tokens array will be merged in-place. # * the candidates array is an array of size len(tokens) - 1. # It contains the token for the merged wordpiece, if it exists, # -1 otherwise. For instance, candidate[3] = id(token[3] + token[4]). # First, split into basic UTF-8 characters (letters). chars = tf.strings.unicode_split(word, 'UTF-8') tokens = self._StringToToken(chars) tokens = tf.where( tf.equal(tokens, NO_TOKEN), # Unseen character. tf.broadcast_to(self.unk_id, tf.shape(tokens)), tokens) # Create initial candidate list. candidates = tf.map_fn(self._MergeTokens, (tokens[:-1], tokens[1:]), dtype=tokens.dtype) def _ShouldMerge(unused_tokens, candidates): """Merge until not possible, or we abort early according to merge_prob.""" return tf.math.logical_and( tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)), tf.random.uniform([]) < self._merge_prob) def _MergeOneToken(tokens, i): return tf.expand_dims(self._MergeTokens( (tokens[i], tokens[i + 1])), axis=-1) def _MergeCandidates(tokens, candidates): """Merge in the reverse binary tree.""" best_id = tf.argmin(candidates, output_type=tf.int32) # Perform the merge at position best_id. tokens = tf.concat([ tokens[:best_id], [candidates[best_id]], tokens[best_id + 2:] ], axis=0) # Recompute the merge candidates. # Only the neighbors of best_id need to be recomputed. empty = tf.zeros([0], dtype=candidates.dtype) def _MergeLeft(): return tf.concat([ candidates[:best_id - 1], _MergeOneToken(tokens, best_id - 1) ], axis=0) left_candidates = tf.cond(tf.equal(best_id, 0), lambda: empty, _MergeLeft) def _MergeRight(): return tf.concat([ _MergeOneToken(tokens, best_id), candidates[best_id + 2:] ], axis=0) right_candidates = tf.cond( tf.greater_equal(best_id, tf.size(tokens) - 1), lambda: empty, _MergeRight) candidates = tf.concat([left_candidates, right_candidates], axis=0) return tokens, candidates return tf.while_loop(_ShouldMerge, _MergeCandidates, (tokens, candidates), parallel_iterations=1, back_prop=False)[0]
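# A pure-Python sketch of the merge loop above, to make the control flow
# concrete. `merge` stands in for self._MergeTokens and `no_token` for the
# NO_TOKEN sentinel (assumed to be larger than any real wordpiece id); both
# names are assumptions for illustration. Unlike the TF version, this sketch
# recomputes all candidates each round and omits the random early abort
# controlled by merge_prob.
def _GreedyMergeSketch(tokens, merge, no_token):
  candidates = [merge((a, b)) for a, b in zip(tokens[:-1], tokens[1:])]
  while any(c != no_token for c in candidates):
    best = candidates.index(min(candidates))  # Lowest merged id wins.
    tokens = tokens[:best] + [candidates[best]] + tokens[best + 2:]
    candidates = [merge((a, b)) for a, b in zip(tokens[:-1], tokens[1:])]
  return tokens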
def ApplyClippingWithState(self, state, x, start_cap=None, end_cap=None,
                           bits=None):
  """Applies clipping.

  The start_cap, end_cap and bits can be set explicitly and take the default
  if None.

  Args:
    state: Clipping state.
    x: Tensor to clip.
    start_cap: Clipping value at the start of the ramp.
    end_cap: Clipping value at the end of the ramp.
    bits: Number of bits to quantize to.

  Returns:
    x with clipping applied.
  """
  p = self.params
  if start_cap is None:
    start_cap = p.start_cap
  if end_cap is None:
    end_cap = p.end_cap
  if bits is None:
    bits = p.bits
  if p.is_inference:
    # For inference, we assume that both clipping and quantization have
    # saturated and just output a saturated quant op.
    min_value, max_value = self._GetCurrentMinMax(
        state, start_cap, end_cap, bits, fixate_to_end_state=True)
    # Note that the inference version uses the *_args variant, which requires
    # constants for min/max. The _GetCurrentMinMax will return (python)
    # constants if fixating. This is fragile but works around a Toco bug
    # if trying to run on the *_vars form because it can't seem to read
    # 0D tensors. This form has the benefit of blowing up at export time
    # if the min/max aren't constant.
    return _CopyShape(
        x,
        tf.quantization.fake_quant_with_min_max_args(
            x, min_value, max_value, num_bits=bits))

  # Non-inference.
  def Clipped():
    clip_ratio = state[0]
    min_value, max_value = self._GetCurrentMinMax(state, start_cap, end_cap,
                                                  bits)
    min_value = tf.stop_gradient(min_value)
    max_value = tf.stop_gradient(max_value)
    # Note that tf.where evaluates both branches eagerly, so wrapping them
    # in lambdas buys no laziness here.
    return tf.where(clip_ratio >= 0.0,
                    tf.clip_by_value(x, min_value, max_value), x)

  def Quantized():
    min_value, max_value = self._GetCurrentMinMax(state, start_cap, end_cap,
                                                  bits)
    min_value = tf.stop_gradient(min_value)
    max_value = tf.stop_gradient(max_value)
    return tf.quantization.fake_quant_with_min_max_vars(
        x, min_value, max_value, num_bits=bits)

  # Quantization will implicitly clip, so if we are in the quant phase, just
  # do that. Otherwise, clip (which will return identity if not in that
  # phase yet).
  fq_ratio = state[1]
  return _CopyShape(x, tf.where(fq_ratio <= 0.0, Clipped(), Quantized()))
def beam_search_step(in_scores,
                     in_atten_probs,
                     in_best_scores,
                     in_cumulative_scores,
                     in_histories,
                     cur_step,
                     eos_id,
                     num_beams,
                     beam_size,
                     num_hyps_per_beam,
                     valid_eos_max_logit_delta=5.0,
                     local_eos_threshold=-100.0,
                     merge_paths=False,
                     is_last_chunk=None,
                     eoc_id=0):
  """A single step of beam search.

  Let "b" be the number of beams, "k" be the number of hyps in each beam.
  This function supports values with dtypes tf.float32 or tf.bfloat16.

  The following data structures are allocated before the first decoding step
  and are passed along from the current step to the next step:

  Args:
    in_scores: A tensor of shape [b * k, vocab_size], where [i, ...] is the
      token score of the j-th hyp of the n-th beam, with j = i // b and
      n = i % b (`mod` sharding).
    in_atten_probs: A tensor of shape [b * k, s_len], where
      in_atten_probs[i, ...] is the attention probabilities over the source
      words of the j-th hyp of the n-th beam (where j and n are derived as
      above).
    in_best_scores: A vector of size [b], best scores of terminated hyps so
      far in each of the beams.
    in_cumulative_scores: A vector of size [b * k]. The cumulative score of
      each active hyp before the current step.
    in_histories: An int32 vector of size [b * k] containing hashes of the
      histories of each active hyp. If 'merge_paths' is enabled, the
      histories are used to identify hypotheses that are identical modulo
      epsilons (e.g. "a <eps> b" and "a b <eps>") and merge them. See
      'update_histories' docstring for details.
    cur_step: Current step id.
    eos_id: Token id of the special end of sequence token.
    num_beams: Number of beams.
    beam_size: Search terminates for a beam once every active hyp scores
      more than this value below the beam's best terminated score.
    num_hyps_per_beam: Number of hyps in a beam.
    valid_eos_max_logit_delta: We allow </s> to terminate a hyp only if its
      logit is no more than 'valid_eos_max_logit_delta' away from the logit
      of the best candidate.
    local_eos_threshold: We allow </s> to terminate a hyp if the local score
      for </s> is greater than local_eos_threshold.
    merge_paths: If true, hyps which are identical when epsilons are removed
      will be combined into a single hyp. The probability for that combined
      hyp will be the sum of the probabilities of the component hyps. This
      can only be applied for epsilon-emitting models (RNN-T and NT).
    is_last_chunk: A tensor of shape [b * k, 1]. Used by neural transducer,
      determines whether the current hypothesis reaches the last chunk and
      should treat the next end-of-chunk symbol as end-of-sentence.
    eoc_id: int, the id of the end of chunk (a.k.a epsilon) token used by
      neural transducer models. Only relevant if 'merge_paths' is True or
      'is_last_chunk' is provided.

  Returns:
    out_best_scores: A tensor of shape [b] of updated best scores for each
      of the beams.
    out_cumulative_scores: A tensor of shape [b * k]. The cumulative score
      of the new hyps after the current decoding step.
    out_scores: A tensor of shape [b * k] with scores of the token selected.
    out_eos_scores: A tensor of shape [b * k] with token scores for the EOS,
      in case the hyp was terminated, otherwise 0.0.
    out_hyps: A tensor of shape [b * k] with ids of the token selected.
    out_prev_hyps: A tensor of shape [b * k] with index to the previous hyps
      which was selected.
    out_done_hyps: A boolean tensor of shape [b * k] where value indicates
      if hyps was terminated.
    out_atten_probs: A tensor of shape [b * k, seq_len] which contain the
      attention probabilities over the source words against word in the
      previous hyps.
    out_eos_atten_probs: A tensor of shape [b * k, seq_len] which contains
      the attention probabilities over the source against word in the
      current hyp which was terminated.
    out_all_done: A scalar, whether decoding should terminate for all beams.
    out_histories: A tensor of shape [b * k] containing new history hashes
      for the active hypotheses. See 'update_histories' docstring for
      details.

  Raises:
    ValueError: if inputs are invalid.
  """
  num_hyps_per_beam = int(num_hyps_per_beam)

  if num_hyps_per_beam <= 0:
    raise ValueError("num_hyps_per_beam = {} and must be > 0.".format(
        num_hyps_per_beam))

  in_scores = tf.convert_to_tensor(in_scores)
  in_scores.shape.assert_has_rank(2)
  num_classes = in_scores.get_shape()[1]

  in_atten_probs = tf.convert_to_tensor(in_atten_probs)
  in_atten_probs.shape.assert_has_rank(2)

  in_best_scores = tf.convert_to_tensor(in_best_scores)
  in_best_scores.shape.assert_has_rank(1)

  in_cumulative_scores = tf.convert_to_tensor(in_cumulative_scores)
  in_cumulative_scores.shape.assert_has_rank(1)

  in_histories = tf.convert_to_tensor(in_histories)
  in_histories.shape.assert_has_rank(1)

  with tf.name_scope("beam_search_step"):
    # For k = num_hyps_per_beam, the first step of beam search is to find
    # the top tokens based on their scores.
    # Normally we select k+1, where the extra +1 is to make sure we have k
    # non-eos tokens to select if EOS token is in the top-k. If path merging
    # is on, we actually need to select k+2; this ensures there are k+1
    # tokens left after the merge, at least k of which are not EOS.
    # TODO(b/118644069): Avoid casts when there is a XLA op available that
    # takes in bfloat16.
    num_candidates_per_input_hyp = (
        num_hyps_per_beam + 2 if merge_paths else num_hyps_per_beam + 1)
    # [b * k, num_candidates_per_input_hyp]
    local_score_values, local_indices = xla_ops.top_k_with_unique(
        tf.cast(in_scores, tf.float32), k=num_candidates_per_input_hyp)
    local_score_values = tf.cast(local_score_values, in_scores.dtype)

    # Compute the global score which is the sum of the local score and the
    # cumulative scores for each of the hyps.
    # [b * k, num_candidates_per_input_hyp]
    global_score_values = local_score_values + tf.expand_dims(
        in_cumulative_scores, 1)

    values_dtype = local_score_values.dtype
    is_first_step = tf.cast(tf.equal(cur_step, 0), values_dtype)

    # Preprocessing to reorder the tensor from `mod` sharding to `div` so
    # that we can use matrix/vector operations to complete the beam search.
    # [b * k, num_candidates_per_input_hyp]
    global_score_values = reorder_tensor("mod_to_div", global_score_values,
                                         num_beams, num_hyps_per_beam)
    local_score_values = reorder_tensor("mod_to_div", local_score_values,
                                        num_beams, num_hyps_per_beam)
    local_indices = reorder_tensor("mod_to_div", local_indices, num_beams,
                                   num_hyps_per_beam,
                                   max_value=num_classes - 1)
    # [b * k, 1]
    histories = reorder_tensor("mod_to_div", tf.expand_dims(in_histories, 1),
                               num_beams, num_hyps_per_beam)

    if is_last_chunk is None:
      is_last_chunk = tf.zeros([num_beams * num_hyps_per_beam, 1], tf.bool)
    else:
      is_last_chunk = tf.cast(
          reorder_tensor(
              "mod_to_div",
              tf.reshape(is_last_chunk, [num_beams * num_hyps_per_beam, 1]),
              num_beams, num_hyps_per_beam), tf.bool)

    # For the first step mask everything but the first row.
    # [num_hyps_per_beam]
    per_example_mask = tf.concat([
        tf.constant([1.0], dtype=values_dtype),
        tf.zeros([num_hyps_per_beam - 1], dtype=values_dtype)
    ], 0)

    # [num_hyps_per_beam, num_beams] => [b*k, 1]
    mask = tf.reshape(
        tf.tile(per_example_mask, tf.expand_dims(num_beams, 0)),
        [-1, 1]) * is_first_step + (1.0 - is_first_step)
    local_score_values *= mask
    global_score_values *= mask

    # We add a large negative value for the unmasked values.
    per_example_additive_mask = tf.concat([
        tf.constant([0.0], dtype=values_dtype),
        tf.constant(
            BEST_SCORES_INIT,
            shape=[num_hyps_per_beam - 1],
            dtype=values_dtype)
    ], 0)
    additive_mask = tf.reshape(
        tf.tile(per_example_additive_mask, tf.expand_dims(num_beams, 0)),
        [-1, 1]) * is_first_step
    local_score_values += additive_mask
    global_score_values += additive_mask

    if merge_paths:
      with tf.name_scope("merge_paths"):
        # Compute new history hashes for each hypothesis + new token.
        # [b * k, num_candidates_per_input_hyp]
        histories = update_histories(
            histories, local_indices, mask, epsilon_id=eoc_id)
        global_score_values, histories = merge_hyps(
            global_score_values, histories, mask, num_beams,
            num_hyps_per_beam)

    # As we keep num_candidates_per_input_hyp, we have a total of
    # num_candidates_per_input_hyp * k hyps active per example.
    num_candidate_hyps = num_candidates_per_input_hyp * num_hyps_per_beam
    batch_shape = [-1, num_candidate_hyps]

    # Reshape score values so that each row corresponds to a particular
    # example. [num_beams, num_candidate_hyps]
    global_score_values_batch = tf.reshape(global_score_values, batch_shape)

    # First for each beam: find the top 2 * num_hyps_per_beam candidates.
    # The factor of 2 is to be able to process non-EOS token ids in the case
    # where the top-scoring token for each hyp is the EOS token.
    # [k * b, 2 * num_hyps_per_beam]
    _, candidates_indices_in_top_k = xla_ops.top_k_with_unique(
        tf.cast(global_score_values_batch, tf.float32),
        k=2 * num_hyps_per_beam)

    # Find the previous hyps of the candidate. We divide here by (k+1) to
    # identify which hyps this token came from.
    hyps_id = candidates_indices_in_top_k // num_candidates_per_input_hyp

    # Add in offset so that we can get the candidate index in the [b * k]
    # space.
    offset = tf.expand_dims(tf.range(num_beams) * num_candidate_hyps, 1)
    flat_candidates_indices_in_top_k = tf.reshape(
        candidates_indices_in_top_k + offset, [-1])

    flat_local_indices = tf.reshape(local_indices, [1, -1])
    flat_token_scores = tf.reshape(local_score_values, [-1, 1])
    flat_global_scores = tf.reshape(global_score_values, [-1, 1])

    # Gather the token scores for each of 2*k candidates. We use tf.one_hot()
    # followed by a tf.matmul() to speedup gather on TPUs.
    total_num_candidates = num_beams * num_candidate_hyps
    token_scores_for_beam = tf.reshape(
        fast_gather(flat_token_scores, flat_candidates_indices_in_top_k,
                    total_num_candidates),
        [num_beams, 2 * num_hyps_per_beam])
    token_scores_for_beam_shape = tf.shape(token_scores_for_beam)

    global_scores_for_beam = tf.reshape(
        fast_gather(flat_global_scores, flat_candidates_indices_in_top_k,
                    total_num_candidates), token_scores_for_beam_shape)

    # Local indices values are between [0, vocab_size-1], hence we use the
    # slower version of gather.
    token_ids_for_beam = tf.reshape(
        fast_gather(
            flat_local_indices,
            flat_candidates_indices_in_top_k,
            total_num_candidates,
            max_value=num_classes - 1,
            axis=1), token_scores_for_beam_shape)

    # We have access to 2*num_hyps_per_beam hyps per beam. We shrink back to
    # num_hyps_per_beam that does not include EOS, and move EOS that occurs
    # in top-num_hyps_per_beam to the EOS done matrix.

    # To determine the threshold at which eos is allowed to terminate a hyp,
    # we need to know the maximum global score for that hyp with any
    # additional token. If path merging is *not* enabled, the
    # global_score_values are by construction in sorted order, so we can
    # just look at its 0th column. If path merging is enabled, the global
    # scores of deleted (merged) hyps break the sorted order, which means we
    # have to do a full reduce_max.
    if merge_paths:
      max_global_score_per_input_hyp = tf.reduce_max(
          global_score_values, axis=1, keepdims=True)
    else:
      max_global_score_per_input_hyp = global_score_values[:, 0:1]
    # [num_beams * num_hyps_per_beam, 1]
    global_eos_threshold = (
        max_global_score_per_input_hyp - valid_eos_max_logit_delta)
    local_eos_threshold_tensor = local_eos_threshold * tf.ones_like(
        global_eos_threshold)

    # Find EOS in top num_hyps_per_beam token ids. We also treat EOC as EOS
    # if the model has indicated this is the last chunk.
    local_index_is_eos = tf.equal(local_indices, eos_id)
    local_index_is_last_chunk_eoc = tf.math.logical_and(
        tf.equal(local_indices, eoc_id), is_last_chunk)
    eos_mask = tf.math.logical_and(
        tf.math.logical_and(
            tf.math.logical_and(
                tf.greater(
                    local_score_values,
                    tf.tile(local_eos_threshold_tensor,
                            [1, num_candidates_per_input_hyp])),
                tf.greater(
                    global_score_values,
                    tf.tile(global_eos_threshold,
                            [1, num_candidates_per_input_hyp]))),
            tf.math.logical_or(local_index_is_eos,
                               local_index_is_last_chunk_eoc)),
        tf.cast(mask, tf.bool))
    end_hyps_bool_mask = tf.reshape(tf.reduce_any(eos_mask, 1), [-1, 1])
    end_hyps_bool_mask = reorder_tensor("div_to_mod", end_hyps_bool_mask,
                                        num_beams, num_hyps_per_beam)

    eos_atten_probs = in_atten_probs * tf.cast(end_hyps_bool_mask,
                                               in_atten_probs.dtype)
    eos_atten_probs = tf.reshape(eos_atten_probs,
                                 [num_beams * num_hyps_per_beam, -1])
    # A boolean tensor of shape [b * k] where value indicates if hyps was
    # terminated.
    out_done_hyps = tf.reshape(end_hyps_bool_mask, [-1])

    # Scores for EOS token.
    eos_float_mask = tf.cast(eos_mask, values_dtype)
    eos_local_scores = eos_float_mask * local_score_values
    eos_additive_float_mask = (1.0 - eos_float_mask) * BEST_SCORES_INIT
    eos_local_scores += eos_additive_float_mask
    out_eos_scores = tf.reshape(tf.reduce_max(eos_local_scores, 1), [-1, 1])
    out_eos_scores = tf.reshape(
        reorder_tensor("div_to_mod", out_eos_scores, num_beams,
                       num_hyps_per_beam), [-1])
    # A tensor of shape [b] of updated best scores for each of the beams.
    eos_global_scores = eos_float_mask * global_score_values
    eos_global_scores += eos_additive_float_mask
    best_scores = tf.reduce_max(
        tf.reshape(eos_global_scores, [num_beams, -1]), 1)

    # The following operations find the top num_hyps_per_beam hyps that are
    # still active, i.e. the ones that do not correspond to EOS termination.
    # We keep num_hyps_per_beam * 2 in case every hyp is terminated by EOS
    # id. Top K with eos removed.
    non_eos_mask = tf.not_equal(token_ids_for_beam, eos_id)
    num_candidate_hyps = num_hyps_per_beam * 2 * num_beams
    index = tf.where(
        non_eos_mask,
        tf.reshape(
            tf.range(num_candidate_hyps, dtype=tf.int32),
            token_scores_for_beam_shape),
        num_candidate_hyps *
        tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape))

    # Unrolled TopK.
    sorted_indices = []
    # Finds the first num_hyps_per_beam unmasked indexes and stores them in
    # concated_sorted_indices (shape: [num_beams, num_candidate_hyps]).
    # This is done by iteratively recording the min index in each row and
    # resetting it to the max, so that the next iteration's reduce_min
    # returns the 2nd minimum index.
    for _ in range(num_hyps_per_beam):
      min_index = tf.reshape(tf.reduce_min(index, [1]), [num_beams, 1])
      sorted_indices.append(min_index)
      # Replace position with num_candidate_hyps value.
      index = tf.where(
          tf.equal(index, min_index),
          num_candidate_hyps *
          tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape), index)

    # Post-processing ops to output expected tensors.
    concated_sorted_indices = tf.concat(sorted_indices, 1)
    flat_sorted_indices = tf.reshape(concated_sorted_indices, [-1])

    # A tensor of shape [b * k] with scores of the token selected.
    out_scores = tf.reshape(
        fast_gather(
            tf.reshape(token_scores_for_beam, [-1, 1]), flat_sorted_indices,
            num_candidate_hyps), [-1, 1])
    out_scores = tf.reshape(
        reorder_tensor("div_to_mod", out_scores, num_beams,
                       num_hyps_per_beam), [-1])

    # Gather the updated histories of selected hypotheses if path merging is
    # enabled. Otherwise, the histories are unused, so just output
    # in_histories.
    if merge_paths:
      flat_histories = tf.reshape(histories, [-1, 1])
      # [num_beams, 2 * num_hyps_per_beam]
      histories_for_beam = tf.reshape(
          fast_gather(flat_histories, flat_candidates_indices_in_top_k,
                      total_num_candidates), token_scores_for_beam_shape)
      out_histories = tf.reshape(
          fast_gather(
              tf.reshape(histories_for_beam, [-1, 1]), flat_sorted_indices,
              num_candidate_hyps), [-1, 1])
      out_histories = tf.reshape(
          reorder_tensor("div_to_mod", out_histories, num_beams,
                         num_hyps_per_beam), [-1])
    else:
      out_histories = in_histories

    prev_hyps_ids = tf.reshape(
        tf.reshape(
            fast_gather(
                tf.reshape(hyps_id, [1, -1]),
                flat_sorted_indices,
                num_candidate_hyps,
                max_value=num_hyps_per_beam,
                axis=1), [num_beams, -1]) * num_beams +
        tf.expand_dims(tf.range(num_beams), 1), [-1, 1])
    prev_hyps_ids = reorder_tensor(
        "div_to_mod",
        prev_hyps_ids,
        num_beams,
        num_hyps_per_beam,
        max_value=num_hyps_per_beam)
    # A tensor of shape [b * k] with index to the previous hyps which was
    # selected.
    out_prev_hyps = tf.reshape(prev_hyps_ids, [-1])

    # A tensor of shape [b * k, seq_len] which contain the attention
    # probabilities over the source words against word in the previous hyps.
    out_atten_probs = tf.reshape(
        fast_gather(in_atten_probs, out_prev_hyps,
                    num_beams * num_hyps_per_beam),
        [num_beams * num_hyps_per_beam, -1])

    sorted_top_k_ids = fast_gather(
        tf.reshape(token_ids_for_beam, [1, -1]),
        flat_sorted_indices,
        num_candidate_hyps,
        max_value=num_classes - 1,
        axis=1)
    sorted_top_k_ids = reorder_tensor(
        "div_to_mod",
        sorted_top_k_ids,
        num_beams,
        num_hyps_per_beam,
        max_value=num_classes - 1,
        axis=1)
    # A tensor of shape [b * k] with ids of the token selected.
    out_hyps = tf.reshape(sorted_top_k_ids, [-1])

    # A tensor of shape [b * k]. The cumulative score of the selected hyps
    # after the current decoding step.
    out_cumulative_scores = tf.reshape(
        fast_gather(
            tf.reshape(global_scores_for_beam, [-1, 1]), flat_sorted_indices,
            num_candidate_hyps), [-1, 1])
    out_cumulative_scores = tf.reshape(
        reorder_tensor("div_to_mod", out_cumulative_scores, num_beams,
                       num_hyps_per_beam), [-1])

    out_best_scores = tf.maximum(best_scores, in_best_scores)

    # A scalar, whether decoding should terminate for all beams.
    out_all_done = tf.reshape(
        tf.math.logical_not(
            tf.reduce_any(
                tf.greater(
                    out_cumulative_scores,
                    tf.reshape(
                        tf.tile(
                            tf.reshape(out_best_scores - beam_size, [-1, 1]),
                            [1, num_hyps_per_beam]), [-1])))), [])

    return (out_best_scores, out_cumulative_scores, out_scores,
            out_eos_scores, out_hyps, out_prev_hyps, out_done_hyps,
            out_atten_probs, eos_atten_probs, out_all_done, out_histories)
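# `reorder_tensor` (defined elsewhere) converts between the two layouts of
# the [b * k, ...] hyp axis used above. A sketch of the index maps for b=2
# beams and k=3 hyps per beam:
#   mod sharding: row i holds hyp (i // b) of beam (i % b)
#     rows = [b0h0, b1h0, b0h1, b1h1, b0h2, b1h2]
#   div sharding: row i holds hyp (i % k) of beam (i // k)
#     rows = [b0h0, b0h1, b0h2, b1h0, b1h1, b1h2]
# "mod_to_div" permutes the first layout into the second, so that each
# beam's hyps are contiguous and can be processed with reshapes/matmuls;
# "div_to_mod" applies the inverse permutation before returning results.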
def _ComputePaddings(ids, eos_id):
  is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
  # eos_in_prefix[i, j] = any(ids[i, k] == eos_id for k in range(j))
  eos_in_prefix = tf.cumsum(is_eos, axis=-1, exclusive=True)
  return tf.where(
      tf.equal(eos_in_prefix, 0), tf.zeros_like(ids), tf.ones_like(ids))
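# Example with illustrative values: with eos_id=2,
#   ids           = [[4, 5, 2, 7]]
#   eos_in_prefix = [[0, 0, 0, 1]]  (exclusive cumsum of the eos mask)
#   result        = [[0, 0, 0, 1]]
# Only tokens strictly after the first eos are marked as padding; the eos
# token itself is kept unpadded.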
def FProp(self, theta, x, x_paddings=None, eos_id=1, force_sample_last_token=True): """Applies SymbolInsertionLayer. We take in a `x`, which represents the groundtruth sequence (i.e., English sequence). We return a sampled rollin (observed) canvas (i.e., random subset of the English sequence), as well as the target (indices) for an insertion-based model (i.e., the targets given the random observed subset). Args: theta: Ignored, this can be None. x: The symbol ids of shape `[batch_size, time_dim]`. x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where 0 is valid and 1 is invalid. eos_id: The <eos> token id to represent end-of-slot. force_sample_last_token: Set True to force sample the last token of `x`. Returns: A `NestedMap`. - canvas: The canvas (based off of the `rollin_policy`) of shape [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be equal. - canvas_indices: The canvas indices (into `x`). - canvas_paddings: The paddings of `canvas_indices`. - target_indices: The target indices of shape [num_targets, 3]. `num_targets` is the number of total targets in the entire batch. [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2] captures the token. Each row [batch, slot, vocab] represents the indices of the target -- i.e., the batch, slot and vocab combination of the target. Typical usage of these indices is to tf.gather_nd the log-probs (from the softmax layer). - target_weights: The target weights. Raises: ValueError: If invalid params. """ p = self.params batch_size = py_utils.GetShape(x)[0] time_dim = py_utils.GetShape(x)[1] if x_paddings is None: x_paddings = tf.zeros([batch_size, time_dim], tf.float32) oracle_policy = p.oracle_policy rollin_policy = (oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy) if rollin_policy != 'uniform': raise ValueError('Unknown or unsupported rollin policy: %s' % rollin_policy) if oracle_policy != 'uniform': raise ValueError('Unknown or unsupported oracle policy: %s' % oracle_policy) x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32) # Compute the desired length per example in the batch. ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed) if force_sample_last_token: c_len = tf.minimum( tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1 else: c_len = tf.minimum( tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len) # Compute the maximum length across the batch. c_len_max = tf.reduce_max(c_len) # Grab subset of random valid indices per example. z_logits = tf.cast( tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1), tf.float32) * -1e9 if force_sample_last_token: # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can # accomplish this by add +LARGE_NUMBER to the logits. z_logits += tf.cast( tf.equal(tf.expand_dims(tf.range(time_dim), 0), tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9 # Gumbel-max trick to sample (we only sample valid positions per sample in # the batch). z = -tf.math.log(-tf.math.log( tf.random.uniform([batch_size, time_dim], seed=p.random_seed))) unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim) # Trim everything > c_len_max. c_indices = c_indices[:, :c_len_max] # Invalidate any indices >= c_len, we use the last index as the default # invalid index. c_indices = tf.where( tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1), c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1)) # Materialize the canvas. 
c_indices = tf.sort(c_indices) c = tf.gather_nd( x, tf.stack([ tf.reshape( tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]), [-1]), tf.reshape(c_indices, [-1]) ], 1)) c = tf.reshape(c, [batch_size, c_len_max]) # Compute the paddings. c_paddings = 1 - tf.sequence_mask( c_len, c_len_max, dtype=x_paddings.dtype) c *= tf.cast(1 - c_paddings, tf.int32) indices = tf.concat([ tf.reshape( tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]), [batch_size * c_len_max, 1]), tf.reshape(c_indices, [batch_size * c_len_max, 1]) ], 1) x_token_is_observed = tf.scatter_nd( indices, tf.ones([batch_size * c_len_max], tf.int32), py_utils.GetShape(x)) # `x_segments` captures which slot each `x` belongs to (both observed and # tokens that need to be observed). x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True) x_token_is_observed = tf.cast(x_token_is_observed, tf.bool) prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True) x_token_is_observed = tf.reshape(x_token_is_observed, [-1]) prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1]) x_is_valid = tf.cast(1 - x_paddings, tf.bool) x_is_valid = tf.reshape(x_is_valid, [-1]) # Remap all the observed to <eos>, note some of these need a zero weight # (or else there would be <eos> and valid token in the same slot). target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32) target_indices = tf.where( x_token_is_observed, tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices) # TODO(williamchan): We give uniform 1.0 weight, however, math suggests # we may want to weigh this term by the original sequence length. target_weights = tf.ones_like(target_indices, tf.float32) # We need to set all the weights for <eos> which actually have valid tokens # in the slot to zero. target_weights = tf.where( x_token_is_observed & ~prev_x_token_is_observed, tf.zeros_like(target_weights), target_weights) # TODO(williamchan): Consider dropping the entries w/ weight zero. # Add the batch and slot indices. target_indices = tf.concat([ tf.reshape( tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]), [batch_size * time_dim, 1]), tf.reshape(x_segments, [-1, 1]), target_indices ], 1) # Select only the valid indices. The selected valid ones include slots w/ # <eos>. target_indices = target_indices[x_is_valid] target_weights = target_weights[x_is_valid] return py_utils.NestedMap(canvas=c, canvas_indices=c_indices, canvas_paddings=c_paddings, target_indices=target_indices, target_weights=target_weights)
def Sample(self, decoder_theta, encoder_outputs, random_seed, init_state_callback, pre_step_callback, post_step_callback): """Samples target sequences, one target sequence per source sequence. (Please see beam_search_helper.py for description of decoder callbacks.) Args: decoder_theta: A NestedMap object containing weights' values of the decoder layer and its children layers, to be passed to decoder callbacks. encoder_outputs: the outputs of the encoder, to be passed to callbacks. random_seed: a scalar int32 tensor representing the random seed. init_state_callback: decoder._InitBeamSearchStateCallback. pre_step_callback: decoder._PreBeamSearchStepCallback. post_step_callback: decoder._PostBeamSearchStepCallback. Returns: A NestedMap containing the following tensors - 'logits': [batch, max_target_length, vocab_size], representing the distribution from which target sequences are sampled. - 'ids': [batch, max_target_length] of int32, representing the target sequence ids, not including target_sos_id, but maybe ending with target_eos_id if end-of-sequence is reached before target_seq_len. - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a padded timestep. """ p = self.params assert p.temperature > 0 if getattr(encoder_outputs, 'segment_id', 1) is None: # Remove None values, which are not supported by recurrent. del encoder_outputs['segment_id'] # init_state_callback may modify 'encoder_outputs', e.g., by inserting # 'packed_src'. bs_result, bs_state = init_state_callback(decoder_theta, encoder_outputs, num_hyps_per_beam=1) # 'recurrent_theta' represents all cross-timestep information used by the # recurrent loop below, including layer theta and encoder outputs. recurrent_theta = py_utils.NestedMap(theta=decoder_theta, random_seed=random_seed, encoder_outputs=encoder_outputs) batch = tf.shape(bs_result.log_probs)[0] recurrent_state0 = py_utils.NestedMap( timestep=tf.zeros(shape=[], dtype=tf.int32), logits=bs_result.log_probs, # Start with target_sos_id. ids=tf.fill([batch], tf.cast(p.target_sos_id, tf.int32)), bs_state=bs_state) inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch])) def Step(recurrent_theta, state0, inputs): """Computes one decoder step.""" del inputs with tf.name_scope('single_sampler_step'): # Compute logits and states. bs_result, bs_state1 = pre_step_callback( recurrent_theta.theta, recurrent_theta.encoder_outputs, tf.expand_dims(state0.ids, 1), # [batch, 1]. state0.bs_state, num_hyps_per_beam=1) batch = tf.shape(bs_result.log_probs)[0] state1 = py_utils.NestedMap(timestep=state0.timestep + 1) state1.logits = bs_result.log_probs # Sample ids from logits. [batch]. 
state1.ids = tf.reshape( tf.random.stateless_categorical( state1.logits / p.temperature, num_samples=1, seed=tf.stack( [recurrent_theta.random_seed, state0.timestep]), dtype=state0.ids.dtype, name='sample_next_id'), [batch]) if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0: state1.ids = tf.where( tf.math.logical_and( bs_result.is_last_chunk, tf.equal(state1.ids, p.target_eoc_id)), tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids) state1.bs_state = post_step_callback( recurrent_theta.theta, recurrent_theta.encoder_outputs, state1.ids, bs_state1) return state1, py_utils.NestedMap() accumulated_states, _ = recurrent.Recurrent( recurrent_theta, recurrent_state0, inputs, Step, allow_implicit_capture=True) result = py_utils.NestedMap(logits=tf.transpose( accumulated_states.logits, [1, 0, 2]), ids=tf.transpose(accumulated_states.ids)) result.paddings = tf.cast( _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype) # Force ids to be eos_id if the timestep is padded. result.ids = tf.where(tf.equal(result.paddings, 0), result.ids, tf.fill(tf.shape(result.ids), p.target_eos_id)) static_batch_size = bs_result.log_probs.shape[0] result.ids.set_shape([static_batch_size, p.target_seq_len]) result.paddings.set_shape([static_batch_size, p.target_seq_len]) return result
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs): """Merges beam search hyps from multiple decoders. Args: max_hyps_per_beam: the number of top hyps in the merged results. Must be less than or equal to total number of input hyps. beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share the same source_batch and max sequence length. Returns: A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per beam. """ source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0] value_dict = {} for output in beam_search_outputs: hyps_per_beam = py_utils.with_dependencies([ py_utils.assert_equal(source_batch, tf.shape(output.topk_hyps)[0]), ], tf.shape( output.topk_hyps)[1]) for k, v in six.iteritems(output._asdict()): if v is None: continue if k == 'done_hyps': v = tf.transpose(v) if k not in value_dict: value_dict[k] = [] value_dict[k].append( tf.reshape(v, [source_batch, hyps_per_beam, -1])) # Concatenate the tensors along the 'num_hyps_per_beam' dimension. concatenated = {} for k, values in six.iteritems(value_dict): if len(values) != len(beam_search_outputs): raise ValueError('Incomplete values for %s: %s' % (k, beam_search_outputs)) concatenated[k] = tf.concat(values, axis=1) scores = concatenated['topk_scores'] scores = tf.where(tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6), scores) scores = tf.squeeze(scores, -1) # Select top max_hyps_per_beam indices per beam. _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam) batch_ids = tf.tile(tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam]) # [source_batch, max_hyps_per_beam, 2] gather_indices = tf.stack([batch_ids, top_indices], axis=-1) # Gather the merged top hyps according to 'gather_indices'. top = beam_search_outputs[0]._asdict() total_hyps = source_batch * max_hyps_per_beam for k, v in six.iteritems(concatenated): v = tf.gather_nd(v, gather_indices) if k == 'done_hyps': v = tf.transpose(tf.reshape(v, [total_hyps, -1])) elif k == 'topk_hyps': v = tf.reshape(v, [source_batch, max_hyps_per_beam]) elif k == 'topk_ids': v = tf.reshape(v, [total_hyps, -1]) elif k in ('topk_lens', 'topk_scores', 'topk_decoded'): v = tf.reshape(v, [total_hyps]) else: raise ValueError('Unexpected field: %s' % k) top[k] = v return BeamSearchDecodeOutput(**top)
def FProp(self, theta, input_batch): """Embeds source ids and transforms with TransformerStack. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. - task_ids: If p.task_emb is provided, must contain per-token task ids of shape [batch, time]. Returns: A NestedMap containing - encoded: The encoded features, either a tensor of shape [time, batch, depth], or a list of tensors if is_transparent is set in transformer_stack. - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. - embedded_inputs: [time, batch, depth] embedded inputs tokens without positional encodings. """ p = self.params with tf.name_scope(p.name): src_segment_id = None src_segment_pos = None input_ids = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(input_batch.ids), tf.shape(input_batch.paddings)), py_utils.assert_equal(tf.rank(input_batch.ids), 2) ], input_batch.ids) if (not py_utils.use_tpu() and tf.flags.FLAGS.transformer_encoder_truncates_inputs): max_seq_length = tf.cast( tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)), tf.int32) paddings = py_utils.with_dependencies([ py_utils.assert_equal( tf.constant(True, tf.bool), tf.reduce_all( input_batch.paddings[:, max_seq_length:] > 0.5)) ], input_batch.paddings) input_ids = input_ids[:, :max_seq_length] paddings = paddings[:, :max_seq_length] if p.packed_input: src_segment_id = input_batch.segment_ids[:, : max_seq_length] src_segment_pos = input_batch.segment_pos[:, : max_seq_length] else: paddings = input_batch.paddings if p.packed_input: src_segment_id = input_batch.segment_ids src_segment_pos = input_batch.segment_pos max_time = tf.shape(input_ids)[1] # Input token embeddings + positional embeddings if not p.shared_emb: input_embs = self.token_emb.EmbLookup( theta.token_emb, tf.reshape(input_ids, [-1])) else: input_embs = self.softmax.EmbLookup( theta.softmax, tf.reshape(input_ids, [-1])) input_embs = tf.reshape(input_embs, [-1, max_time, p.token_emb.embedding_dim]) # [time, batch, dim] orig_input_embs = tf.transpose(input_embs, [1, 0, 2]) if p.packed_input: position_embs = self.position_emb.FPropWithPosition( theta.position_emb, src_segment_pos) else: position_embs = self.position_emb.FProp( theta.position_emb, max_time) position_embs = tf.reshape( position_embs, [1, max_time, p.token_emb.embedding_dim]) input_embs += position_embs if p.task_emb: input_embs += self.task_emb.EmbLookup(theta.task_emb, input_batch.task_ids) if p.model_dim != p.token_emb.embedding_dim: input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs) paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p)) if p.packed_input: src_segment_id = tf.transpose(src_segment_id) input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs) # [time, batch, dim] transformer_input = tf.transpose(input_embs, [1, 0, 2]) if not self.do_eval and p.apply_source_mask: # Augment padding for masked source word positions. dtype = paddings.dtype source_mask = tf.where(tf.equal(input_ids, p.source_mask_id), tf.ones_like(input_ids, dtype=dtype), tf.zeros_like(input_ids, dtype=dtype)) # Make sure padding is between 0 and 1. 
paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0, 1.0) encoded, padding, segment_id = self.transformer_stack.FProp( theta.transformer_stack, transformer_input, paddings, src_segment_id) return py_utils.NestedMap(encoded=encoded, padding=padding, segment_id=segment_id, embedded_inputs=orig_input_embs)
def _CreateCanvasAndTargets(self, batch): # pyformat: disable """Create the canvas and targets. Args: batch: A `.NestedMap`. - src: A `.NestedMap`. - ids: The source ids, ends in <eos>. - paddings: The source paddings. - tgt: A `.NestedMap`. - ids: The target ids, ends in <eos>. - paddings: The target paddings. Returns: A `NestedMap`. - canvas: The canvas (based off of the `rollin_policy`) of shape [batch_size, c_dim]. - canvas_paddings: The paddings of `canvas_indices`. - target_indices: The target indices (i.e., use these indices to tf.gather_nd the log-probs). Optional, only during training. - target_weights: The target weights. Optional, only during training. """ # pyformat: enable p = self.params if not self.do_eval: # Sample our src and tgt canvas. src_descriptor = self._SampleCanvasAndTargets( batch.src.ids, batch.src.paddings) tgt_descriptor = self._SampleCanvasAndTargets( batch.tgt.ids, batch.tgt.paddings) # Offset the src ids (to unshare embeddings between src/tgt). Note, we # only offset the canvas ids, but we do not offset the vocab ids. This # will result in unshared embeddings, but shared softmax. This is due to # GPU/TPU memory limitations, empirically it is known that unsharing # everything results in better performance. vocab_size = p.decoder.softmax.num_classes src_descriptor.canvas = tf.where( tf.equal(src_descriptor.canvas_paddings, 0), src_descriptor.canvas + vocab_size, src_descriptor.canvas) # Offset the tgt indices (need shift according to src length). batch_size = py_utils.GetShape(batch.src.ids)[0] # `target_batch` is a [num_targets, batch_size] tensor where each row # identifies which batch the target belongs to. Note the observation that, # tf.reduce_sum(target_batch, 1) == 1 \forall rows. target_batch = tf.cast( tf.equal( tf.expand_dims(tf.range(batch_size), 0), tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)), tf.int32) src_lens = tf.cast( tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32) # `tgt_offset` is shape [num_targets] where each entry corresponds to the # offset needed for that target (due to the source length). tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1)) # We shift the tgt slot without touching the batch or vocab. tgt_descriptor.target_indices += tf.concat([ tf.zeros_like(tgt_offset), tgt_offset, tf.zeros_like(tgt_offset) ], 1) # The canvas is simply the sequence-level concat of the src and tgt. canvas, canvas_paddings = insertion.SequenceConcat( src_descriptor.canvas, src_descriptor.canvas_paddings, tgt_descriptor.canvas, tgt_descriptor.canvas_paddings) target_indices = tf.concat( [src_descriptor.target_indices, tgt_descriptor.target_indices], 0) target_weights = tf.concat( [src_descriptor.target_weights, tgt_descriptor.target_weights], 0) return py_utils.NestedMap(canvas=canvas, canvas_paddings=canvas_paddings, target_indices=target_indices, target_weights=target_weights)
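# Sketch of the offset arithmetic above with assumed toy values: a batch of
# 2, src canvas lengths src_lens = [3, 5], and two targets belonging to
# batches [0, 1]. Then target_batch is the one-hot matrix [[1, 0], [0, 1]],
# and
#   tgt_offset = target_batch @ [[3], [5]] = [[3], [5]]
# so each target's slot index is shifted by its own example's source canvas
# length, aligning tgt slots after the src portion of the concatenated
# canvas.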