def Apply(self, lr, var_grad):
  p = self.params

  def _Acc(vg):
    """Updating accumulators."""
    v, g = vg
    with tf.variable_scope(v.op.name):
      _, a = py_utils.CreateVariable(
          'grad_accumulator',
          py_utils.WeightParams(v.get_shape(),
                                py_utils.WeightInit.Constant(0.0),
                                self.params.dtype),
          trainable=False)
      a = tf.assign_add(a, g)

    return py_utils.VarGrad(v, a)

  var_grad = var_grad.Transform(_Acc)

  def _ApplyAndReset():
    with tf.control_dependencies([
        self._opt.Apply(
            lr, py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps))
    ]):
      return tf.group(
          *[tf.assign(a, tf.zeros_like(a)) for _, a in var_grad.Flatten()])

  return tf.cond(
      tf.equal(tf.mod(self.theta.global_step, p.accum_steps),
               p.accum_steps - 1),
      _ApplyAndReset, lambda: tf.group(tf.no_op()))
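# For context: a minimal sketch of the same accumulate-then-apply pattern
# outside Lingvo, assuming TF1-style graph execution and a single variable.
# All names here (`w`, `accum`, `train_op`) are illustrative, not part of the
# code above.
import tensorflow.compat.v1 as tf

accum_steps = 4
lr = 0.1
global_step = tf.train.get_or_create_global_step()
w = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())
grad = tf.ones_like(w)  # Stand-in for a real gradient.
accum = tf.get_variable(
    'w_accum', shape=[2], trainable=False, initializer=tf.zeros_initializer())

# Every step adds the current gradient into the accumulator.
accumulate = tf.assign_add(accum, grad)


def _ApplyAndReset():
  # Apply the mean of the accumulated gradients, then zero the accumulator.
  with tf.control_dependencies([tf.assign_sub(w, lr * accum / accum_steps)]):
    return tf.group(tf.assign(accum, tf.zeros_like(accum)))


with tf.control_dependencies([accumulate]):
  # Only apply on the last step of every accumulation window.
  train_op = tf.cond(
      tf.equal(tf.mod(global_step, accum_steps), accum_steps - 1),
      _ApplyAndReset, lambda: tf.group(tf.no_op()))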
def _GetBucketKey(self, features, filtered):
  """Returns the bucket key for a given input."""
  # The token ids are not truncated if and only if the sequence ends with
  # padding or the last id is EOS.
  src_fits = tf.math.logical_or(
      tf.math.equal(features.src.ids_indicator[-1], 0),
      tf.math.equal(features.src.ids[-1], self._src_tokenizer.eos_id))
  tgt_fits = tf.math.logical_or(
      tf.math.equal(features.tgt.ids_indicator[-1], 0),
      tf.math.equal(features.tgt.labels[-1], self._tgt_tokenizer.eos_id))

  # We return the max of the source or target sequence length if and only if
  # both src and tgt fit. Otherwise we return a key of -1 to filter out this
  # input.
  def _MaxLen():
    src_len = tf.cast(
        tf.math.reduce_sum(features.src.ids_indicator), dtype=tf.int32)
    tgt_len = tf.cast(
        tf.math.reduce_sum(features.tgt.ids_indicator), dtype=tf.int32)
    return tf.math.maximum(src_len, tgt_len)

  filtered = tf.math.logical_or(
      filtered, tf.math.logical_not(tf.math.logical_and(src_fits, tgt_fits)))
  return tf.cond(filtered, lambda: -1, _MaxLen)
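# The -1-means-drop convention above can be seen in isolation; a minimal
# sketch, assuming `ids_indicator` is a 0/1 vector marking real (non-padded)
# positions. `bucket_key` and `max_len` are illustrative names.
import tensorflow.compat.v1 as tf


def bucket_key(ids_indicator, max_len):
  seq_len = tf.cast(tf.reduce_sum(ids_indicator), tf.int32)
  fits = tf.less_equal(seq_len, max_len)
  # Bucket by length when the sequence fits; return -1 to drop it otherwise.
  return tf.cond(fits, lambda: seq_len, lambda: tf.constant(-1, tf.int32))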
def conditional_mask_update_op(self):

  def maybe_update_masks():
    with tf.name_scope(self._spec.name):
      is_step_within_pruning_range = tf.logical_and(
          tf.greater_equal(self._global_step, self._spec.begin_pruning_step),
          # If end_pruning_step is negative, keep pruning forever!
          tf.logical_or(
              tf.less_equal(self._global_step, self._spec.end_pruning_step),
              tf.less(self._spec.end_pruning_step, 0)))
      is_pruning_step = tf.less_equal(
          tf.add(self._last_update_step, self._spec.pruning_frequency),
          self._global_step)
      return tf.logical_and(is_step_within_pruning_range, is_pruning_step)

  def mask_update_op():
    return self.mask_update_op()

  def no_update_op():
    return tf.no_op()

  return tf.cond(maybe_update_masks(), mask_update_op, no_update_op)
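# The pruning-schedule predicate is easier to read outside the graph; a plain
# Python sketch of the same logic, with `begin`, `end`, `freq`, and
# `last_update` as illustrative stand-ins for the spec fields.
def should_update_mask(step, last_update, begin, end, freq):
  # Within the pruning range; a negative `end` means prune forever.
  in_range = step >= begin and (step <= end or end < 0)
  # Enough steps have elapsed since the last mask update.
  due = last_update + freq <= step
  return in_range and due


assert should_update_mask(step=100, last_update=50, begin=0, end=-1, freq=10)
assert not should_update_mask(step=55, last_update=50, begin=0, end=-1, freq=10)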
def Processor(source_id, record):
  """Parses a record, which is a line of text."""
  task_id = self._GetTaskIds(source_id)

  if self.params.input_file_type == 'tsv':

    def _ApplyMass(task_id):
      mass_task_ids = tf.constant(self.params.mass_task_ids, dtype=tf.int32)
      return tf.reduce_any(tf.equal(task_id, mass_task_ids))

    def _MASSInput():
      src, filtered = self._ReadRecordTsvSingleColumn(record)
      return self._ProcessMASSInput(source_id, src), filtered

    def _SingleInput():
      src, tgt, filtered = self._ReadRecordTsv(record)
      return self._ProcessSingleInput(source_id, src, tgt), filtered

    if self.params.single_column_input:
      # For monolingual input, MASS is applied by default.
      # If mass_task_ids is specified, only apply MASS to specified tasks.
      if self.params.mass_task_ids is not None:
        cond = _ApplyMass(task_id)
        features, filtered = tf.cond(cond, _MASSInput, _SingleInput)
      else:
        features, filtered = _MASSInput()
    else:
      features, filtered = _SingleInput()
  else:
    src, tgt = self._ReadRecordSentencePairProto(record)
    filtered = tf.constant(False, dtype=tf.bool)
    features = self._ProcessSingleInput(source_id, src, tgt)

  return features, self._GetBucketKey(features, filtered)
def bucket_fn(num):
  # Drops record if num[0] is odd.
  return tf.cond(
      tf.equal(tf.math.floormod(num[0], 2), 0), lambda: 1,
      lambda: -tf.cast(num[0], tf.int32))
def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                              num_hyps_per_beam, *args, **kwargs):
  """Wrapper for adding bias to _PreBeamSearchStateCallback.

  Biases results.log_probs towards provided encoder_outputs.targets.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by encoder.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps.
    num_hyps_per_beam: Beam size.
    *args: additional arguments to _PreBeamSearchStepCallback.
    **kwargs: additional arguments to _PreBeamSearchStepCallback.

  Returns:
    A tuple (results, out_states).

    results: A `.NestedMap` of beam search results.

      atten_probs: The updated attention probs, of shape
        [tgt_batch, src_len].
      log_probs: Log prob for each of the tokens in the target vocab. This
        is of shape [tgt_batch, vocab_size].

    out_states: A `.NestedMap` of updated states. The states relevant here
      are:

      time_step: A scalar indicating current step of decoder. Must be
        provided and maintained by subclass.
      consistent: A boolean vector of shape [tgt_batch, ] which tracks
        whether each hypothesis has exactly matched
        encoder_outputs.targets so far.
  """
  p = self.params
  time_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)
  labels = encoder_outputs.targets.labels
  weights = encoder_outputs.targets.weights

  def ApplyBias():
    """Bias and update log_probs and consistent."""

    def TileForBeamAndFlatten(tensor):
      tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
      tensor = tf.tile(
          tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
      tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
      return tf.reshape(tensor, [tgt_batch])

    # Consistent if step_ids == labels from previous step.
    # TODO(navari): Consider updating consistent only if weights > 0. Then
    # re-evaluate the need for bias_only_if_consistent=True.
    # Note that prev_label is incorrect for step 0 but is overridden later.
    prev_label = TileForBeamAndFlatten(
        tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
    is_step0 = tf.equal(time_step, 0)
    local_consistence = tf.logical_or(
        is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
    consistent = tf.logical_and(states.consistent, local_consistence)

    # Get label, weight slices corresponding to current time_step.
    label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
    weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
    if p.bias_only_if_consistent:
      weight = weight * tf.cast(consistent, p.dtype)

    # Convert from dense label to sparse label probs.
    vocab_size = tf.shape(bs_results.log_probs)[1]
    uncertainty = tf.constant(
        1e-10, p.dtype)  # Avoid 0 probs which may cause issues with log.
    label_probs = tf.one_hot(
        label,
        vocab_size,
        on_value=1 - uncertainty,
        off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
        dtype=p.dtype)  # [tgt_batch, vocab_size]
    pred_probs = tf.exp(bs_results.log_probs)

    # Interpolate predicted probs and label probs.
    weight = tf.expand_dims(weight, 1)
    probs = py_utils.with_dependencies([
        py_utils.assert_less_equal(weight, 1.),
        py_utils.assert_greater_equal(weight, 0.)
    ], (1.0 - weight) * pred_probs + weight * label_probs)
    return tf.log(probs), consistent

  def NoApplyBias():
    """No-op. Return original log_probs and consistent."""
    return bs_results.log_probs, states.consistent

  log_probs, consistent = tf.cond(
      tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias, ApplyBias)
  bs_results.log_probs = log_probs
  out_states.consistent = consistent
  return bs_results, out_states
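# The core of ApplyBias is a per-token interpolation between the model
# distribution and a smoothed one-hot target distribution; a minimal sketch
# with illustrative shapes (2 hypotheses, vocab of 4). weight=1 forces the
# target label, weight=0 keeps the model distribution.
import tensorflow.compat.v1 as tf

log_probs = tf.log(
    tf.constant([[0.7, 0.1, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25]]))
label = tf.constant([2, 3])
weight = tf.constant([[1.0], [0.0]])  # [tgt_batch, 1]

eps = 1e-10  # Avoid exactly-zero probs before the log.
vocab = tf.shape(log_probs)[1]
label_probs = tf.one_hot(
    label, vocab, on_value=1.0 - eps,
    off_value=eps / tf.cast(vocab - 1, tf.float32))
probs = (1.0 - weight) * tf.exp(log_probs) + weight * label_probs
biased_log_probs = tf.log(probs)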
def _Real():
  return tf.cond(
      tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)
def _Seeded():
  return tf.cond(
      tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
      _GetFurthestPoint)
def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
  """Loop body for farthest point sampler."""

  def _GetRandomRealPoint():
    """Select the first point.

    For the first point, we want any random real (non padded) point, so we
    create random values per point, and then set all padded ones to some
    large value (more than the maxval). We then take the min per batch
    element to get the first points.

    Returns:
      Tensor containing the index of a random point selected for each example
      in the batch.
    """
    random_values = tf.random.uniform((batch_size, num_points),
                                      minval=0,
                                      maxval=1,
                                      dtype=tf.float32,
                                      seed=random_seed)
    random_values = tf.where(
        tf.equal(padding, 0.0), random_values, padding * 10)
    return tf.argmin(random_values, axis=1, output_type=tf.int32)

  def _GetFurthestPoint():
    """Get the point that is furthest from those already selected.

    We also bias the sampling towards real points by setting the distance to
    padded points negative until we are out of real points.

    Returns:
      Tensor containing the index of the next furthest point selected for
      each example in the batch.
    """
    # Set padded points distance to negative so they aren't selected.
    padding_masked_distance_to_selected = tf.where(
        tf.equal(padding, 0.0), distance_to_selected,
        -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32))
    # But only do this when we still have valid points left.
    padding_masked_distance_to_selected = tf.where(
        tf.less(curr_idx, num_valid_points),
        padding_masked_distance_to_selected, distance_to_selected)
    return tf.argmax(
        padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

  def _GetSeededPoint():
    """Select a seeded point.

    Seeded points are assumed to be at the beginning of the original points.

    Returns:
      Tensor containing the index of the next seeded point to select for each
      example in the batch.
    """
    return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

  # Select indices for this loop iteration.
  def _Seeded():
    return tf.cond(
        tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
        _GetFurthestPoint)

  def _Real():
    return tf.cond(
        tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

  new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
  sampled_idx = sampled_idx.write(curr_idx, new_selected)

  # Extract the distance to the latest point selected to update
  # distance_to_selected.
  new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                     axis=1)
  if precomputed_squared_distance is not None:
    new_distance = tf.gather_nd(precomputed_squared_distance,
                                new_selected_gather_idx)
  else:
    new_points = tf.reshape(
        tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
    new_distance = tf.reshape(
        SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

  is_newly_closest = tf.less(new_distance, distance_to_selected)
  distance_to_selected = tf.minimum(distance_to_selected, new_distance)

  # Track the index to the closest selected point.
  new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
  closest_idx = tf.cond(
      tf.equal(curr_idx, 0),
      # At the first loop iteration, the init points are the closest.
      lambda: new_selected_tiled,
      # Otherwise, update with the new points based on the distances.
      lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
  return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
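# For intuition, the same greedy farthest-point selection fits in a few lines
# of NumPy; a minimal sketch without padding or seed points, with all names
# illustrative: start from a random point, then repeatedly pick the point
# farthest from everything selected so far.
import numpy as np


def farthest_point_sample(points, num_samples, rng):
  """Greedy farthest point sampling over a [num_points, dims] array."""
  num_points = points.shape[0]
  selected = [int(rng.integers(num_points))]
  # Squared distance from every point to its nearest selected point.
  dist = np.sum((points - points[selected[0]]) ** 2, axis=1)
  for _ in range(num_samples - 1):
    nxt = int(np.argmax(dist))
    selected.append(nxt)
    dist = np.minimum(dist, np.sum((points - points[nxt]) ** 2, axis=1))
  return np.array(selected)


# Example: sample 4 points from 100 random 3-D points.
pts = np.random.default_rng(1).standard_normal((100, 3))
print(farthest_point_sample(pts, 4, rng=np.random.default_rng(0)))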
def ProcessFeatures(self, features):
  """Process extracted features.

  Args:
    features: A dict of extracted Tensors from the records.

  Returns:
    A tuple of tensors:

    - bucket_id: A scalar int Tensor.
    - extracted: a NestedMap of Tensors extracted.
  """

  def ExtractAndFilter(e):
    with tf.name_scope(e.params.name):
      with tf.name_scope('extract'):
        # Filter out extracted features from other extractors.
        filtered_features = {}
        if self.params.record_type == 'TEXT':
          # Text extractors only produce {'line': record} and their
          # FeatureMap() is empty, so don't do any filtering.
          filtered_features = features
        else:
          filtered_keys = e.FeatureMap().keys() | e.ContextMap().keys()
          filtered_features = {
              k: v for k, v in features.items() if k in filtered_keys
          }
        try:
          if self.params.batched_input:
            extracted = e.ExtractBatch(filtered_features)
          else:
            extracted = e.Extract(filtered_features)
        except Exception as exc:  # pylint:disable=broad-except
          # Raise exception with context about which extractor failed.
          raise RuntimeError('Failed running extractor '
                             f'{e.params.name}. '
                             'See above exception for details.') from exc
      with tf.name_scope('filter'):
        if self.params.batched_input:
          bucket = e.FilterBatch(extracted)
        else:
          bucket = e.Filter(extracted)
    return bucket, extracted

  bucket_extracted = self._extractors.Transform(ExtractAndFilter)
  buckets = bucket_extracted.Transform(lambda x: x[0])
  extracted = bucket_extracted.Transform(lambda x: x[1])

  # Return the maximum bucket id so that any extractor can decide whether
  # to filter the entire example.
  max_bucket = tf.reduce_max(buckets.Flatten())

  def NullLike():
    """A function to return the same Tensor signature as Preprocess.

    This is necessary for the tf.cond() to avoid executing the preprocessor
    for examples that are going to be dropped because they exceed the bucket
    limit; tf.cond() requires that the outputs of both branches yield the
    same structure.

    Returns:
      A structure with the same Tensor dtype as the output of Preprocess.
    """
    shapes = self.Shape()
    rets = []
    for dtype, shape in zip(self.DType().Flatten(), shapes.Flatten()):
      if shape.is_fully_defined():
        rets += [tf.zeros(dtype=dtype, shape=shape)]
      else:
        rets += [tf.zeros(dtype=dtype, shape=[])]  # Our best guess.
    return shapes.Pack(rets)

  def Preprocess(extracted):
    for key, preprocessor in zip(self.params.preprocessors_order,
                                 self.preprocessors):
      with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
        if self.params.batched_input:
          extracted = preprocessor.TransformBatchedFeatures(extracted)
        else:
          extracted = preprocessor.TransformFeatures(extracted)
    return extracted

  # If the extractor wants to filter the example, don't run the preprocessor.
  #
  # Preprocessors can then assume that only examples that pass filtering will
  # be executed.
  #
  # Note that the NullLike branch may return tensors with shapes different
  # from self.Shape().
  final_output = tf.cond(
      tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
      NullLike)

  return max_bucket, final_output
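# Both branches of tf.cond must return the same nested structure and dtypes,
# which is exactly what NullLike guarantees. A minimal sketch of the pattern;
# `expensive_fn` stands in for the real preprocessor pipeline and all names
# are illustrative.
import tensorflow.compat.v1 as tf


def expensive_fn(x):
  return {'feature': x * 2.0, 'length': tf.size(x)}


def null_like():
  # Zeros with the same structure/dtypes as expensive_fn's output, so the
  # expensive branch can be skipped for filtered examples.
  return {'feature': tf.zeros([3]), 'length': tf.zeros([], tf.int32)}


x = tf.constant([1.0, 2.0, 3.0])
keep = tf.constant(True)  # Stand-in for tf.less(max_bucket, BUCKET_UPPER_BOUND).
out = tf.cond(keep, lambda: expensive_fn(x), null_like)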
def ExtractUsingExtractors(self, record):
  """Extracts Tensors from a tf.Example record using self.extractors.

  Args:
    record: A tf.Example input to pass to tf.parse_single_example.

  Returns:
    A tuple of tensors:

    - bucket_id: A scalar int Tensor.
    - extracted: a NestedMap of Tensors extracted.
  """
  feature_map = {}
  context_map = {}
  self._extractors.Transform(lambda e: feature_map.update(e.FeatureMap()))
  if self.params.record_type == 'SEQUENCE_EXAMPLE':
    self._extractors.Transform(lambda e: context_map.update(e.ContextMap()))

  if self.params.record_type not in _PARSING_FUNCTIONS:
    raise ValueError('Invalid record_type: {}'.format(
        self.params.record_type))

  parsing_fn = _PARSING_FUNCTIONS[self.params.record_type]
  if self.params.record_type == 'SEQUENCE_EXAMPLE':
    features = parsing_fn(record, feature_map, context_map)
  else:
    features = parsing_fn(record, feature_map)

  def ExtractAndFilter(e):
    with tf.name_scope(e.params.name):
      with tf.name_scope('extract'):
        extracted = e.Extract(features)
      with tf.name_scope('filter'):
        bucket = e.Filter(extracted)
    return bucket, extracted

  bucket_extracted = self._extractors.Transform(ExtractAndFilter)
  buckets = bucket_extracted.Transform(lambda x: x[0])
  extracted = bucket_extracted.Transform(lambda x: x[1])

  # Return the maximum bucket id so that any extractor can decide whether
  # to filter the entire example.
  max_bucket = tf.reduce_max(buckets.Flatten())

  def NullLike():
    """A function to return the same Tensor signature as Preprocess.

    This is necessary for the tf.cond() to avoid executing the preprocessor
    for examples that are going to be dropped because they exceed the bucket
    limit; tf.cond() requires that the outputs of both branches yield the
    same structure.

    Returns:
      A structure with the same Tensor dtype and shape as the output of
      Preprocess.
    """
    shapes = self.Shape()
    rets = [
        tf.zeros(dtype=dtype, shape=shape)
        for (dtype, shape) in zip(self.DType().Flatten(), shapes.Flatten())
    ]
    return shapes.Pack(rets)

  def Preprocess(extracted):
    for key, preprocessor in zip(self.params.preprocessors_order,
                                 self.preprocessors):
      with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
        extracted = preprocessor.TransformFeatures(extracted)
    return extracted

  # If the extractor wants to filter the example, don't run the preprocessor.
  #
  # Preprocessors can then assume that only examples that pass filtering will
  # be executed.
  final_output = tf.cond(
      tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
      NullLike)

  return max_bucket, final_output
def _InputBatch(self):
  np.random.seed(1)
  bs, sl = 10, 7
  src_ids = tf.constant(
      np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
  tgt_ids = tf.constant(
      np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
  tgt_labels = tf.constant(
      np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
  tgt_weights = tf.constant(np.ones(shape=[bs, sl], dtype=np.float32))

  src_paddings = tf.zeros([bs, sl])
  tgt_paddings = tf.zeros([bs, sl])

  ret = py_utils.NestedMap()
  ret.src = py_utils.NestedMap()
  ret.tgt = py_utils.NestedMap()

  if self.params.split:
    src_ids = tf.split(src_ids, 2, 0)
    src_paddings = tf.split(src_paddings, 2, 0)
    tgt_ids = tf.split(tgt_ids, 2, 0)
    tgt_labels = tf.split(tgt_labels, 2, 0)
    tgt_paddings = tf.split(tgt_paddings, 2, 0)
    tgt_weights = tf.split(tgt_weights, 2, 0)

    ret.src.ids = tf.cond(
        tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
        lambda: src_ids[0], lambda: src_ids[1])
    ret.src.paddings = tf.cond(
        tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
        lambda: src_paddings[0], lambda: src_paddings[1])
    ret.tgt.ids = tf.cond(
        tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_ids[0], lambda: tgt_ids[1])
    ret.tgt.labels = tf.cond(
        tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_labels[0], lambda: tgt_labels[1])
    ret.tgt.paddings = tf.cond(
        tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_paddings[0], lambda: tgt_paddings[1])
    ret.tgt.weights = tf.cond(
        tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_weights[0], lambda: tgt_weights[1])
  else:
    ret.src.ids = src_ids
    ret.src.paddings = src_paddings
    ret.tgt.ids = tgt_ids
    ret.tgt.labels = tgt_labels
    ret.tgt.paddings = tgt_paddings
    ret.tgt.weights = tgt_weights

  return ret
def _Wrap(fn, x, y):
  if not self._cond_is_finite:
    return fn(x, y)
  return tf.cond(cond, lambda: fn(x, y), lambda: x)
def ProcessFeatures(self, features):
  """Process extracted features.

  Args:
    features: A dict of extracted Tensors from the records.

  Returns:
    A tuple of tensors:

    - bucket_id: A scalar int Tensor.
    - extracted: a NestedMap of Tensors extracted.
  """

  def ExtractAndFilter(e):
    with tf.name_scope(e.params.name):
      with tf.name_scope('extract'):
        extracted = e.Extract(features)
      with tf.name_scope('filter'):
        bucket = e.Filter(extracted)
    return bucket, extracted

  bucket_extracted = self._extractors.Transform(ExtractAndFilter)
  buckets = bucket_extracted.Transform(lambda x: x[0])
  extracted = bucket_extracted.Transform(lambda x: x[1])

  # Return the maximum bucket id so that any extractor can decide whether
  # to filter the entire example.
  max_bucket = tf.reduce_max(buckets.Flatten())

  def NullLike():
    """A function to return the same Tensor signature as Preprocess.

    This is necessary for the tf.cond() to avoid executing the preprocessor
    for examples that are going to be dropped because they exceed the bucket
    limit; tf.cond() requires that the outputs of both branches yield the
    same structure.

    Returns:
      A structure with the same Tensor dtype as the output of Preprocess.
    """
    shapes = self.Shape()
    rets = []
    for dtype, shape in zip(self.DType().Flatten(), shapes.Flatten()):
      if shape.is_fully_defined():
        rets += [tf.zeros(dtype=dtype, shape=shape)]
      else:
        rets += [tf.zeros(dtype=dtype, shape=[])]  # Our best guess.
    return shapes.Pack(rets)

  def Preprocess(extracted):
    for key, preprocessor in zip(self.params.preprocessors_order,
                                 self.preprocessors):
      with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
        extracted = preprocessor.TransformFeatures(extracted)
    return extracted

  # If the extractor wants to filter the example, don't run the preprocessor.
  #
  # Preprocessors can then assume that only examples that pass filtering will
  # be executed.
  #
  # Note that the NullLike branch may return tensors with shapes different
  # from self.Shape().
  final_output = tf.cond(
      tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
      NullLike)

  return max_bucket, final_output
def Callback(theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
             *args, **kwargs):
  p = self.params
  time_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  if biased:
    labels = encoder_outputs.targets.labels
    weights = encoder_outputs.targets.weights

    def ApplyBias():
      """Bias and update log_probs and consistent."""
      # Consistent if step_ids == labels from previous step.
      # TODO(navari): Consider updating consistent only if weights > 0. Then
      # re-evaluate the need for bias_only_if_consistent=True.
      # Note that prev_label is incorrect for step 0 but is overridden later.
      prev_label = TileForBeamAndFlatten(
          tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
      is_step0 = tf.equal(time_step, 0)
      local_consistence = tf.math.logical_or(
          is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
      consistent = tf.math.logical_and(states.consistent, local_consistence)

      # Get label, weight slices corresponding to current time_step.
      label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
      weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
      if p.bias_only_if_consistent:
        weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

      # Convert from dense label to sparse label probs.
      vocab_size = tf.shape(bs_results.log_probs)[1]
      label_probs = tf.one_hot(
          label, vocab_size,
          dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
      pred_probs = tf.exp(bs_results.log_probs)

      # Interpolate predicted probs and label probs.
      weight = tf.expand_dims(weight, 1)
      probs = py_utils.with_dependencies([
          py_utils.assert_less_equal(weight, 1.),
          py_utils.assert_greater_equal(weight, 0.)
      ], (1.0 - weight) * pred_probs + weight * label_probs)
      # Ensure that tf.math.log is applied to positive values.
      probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
      return tf.math.log(probs), consistent

    def NoApplyBias():
      """No-op. Return original log_probs and consistent."""
      return bs_results.log_probs, states.consistent

    log_probs, consistent = tf.cond(
        tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias, ApplyBias)
    bs_results.log_probs = log_probs
    out_states.consistent = consistent

  if stochastic:
    log_probs = bs_results.log_probs

    def PerturbedLogProbs():
      # STEP 1: Perform top-k filtering. This is done as a performance
      # optimization to avoid sorting the entire `log_probs`, which is
      # prohibitively slow.
      top_k = tf.math.top_k(log_probs, k, sorted=True)
      # shape: [tgt_batch, k]
      top_k_log_probs = top_k.values
      # shape: [tgt_batch, k]
      top_k_ids = top_k.indices

      # STEP 2: Perform top-p filtering.
      # shape: [tgt_batch]
      top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
      top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
      top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
      # shape: [tgt_batch, k]
      filtered_top_k_log_probs = _KeepTopP(top_k_log_probs, top_p_threshold)

      # STEP 3: Perturb cumulative log-probs.
      # shape: [tgt_batch, 1]
      last_cumulative_log_probs = states.cumulative_log_probs
      # shape: [tgt_batch, 1]
      last_perturbed_cumulative_log_probs = (
          states.perturbed_cumulative_log_probs)

      # Compute cumulative log-probs of the current step.
      # shape: [tgt_batch, k]
      cumulative_log_probs = (
          last_cumulative_log_probs + filtered_top_k_log_probs)
      # Perturb cumulative log-probs by Gumbel noises under the condition
      # that the max of the new perturbed log-probs is equal to
      # perturbed_cumulative_log_probs of the previous step.
      # shape: [tgt_batch, k]
      new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
          cumulative_log_probs, last_perturbed_cumulative_log_probs,
          encoder_outputs.stochastic_beam_search.seed, time_step,
          encoder_outputs.stochastic_beam_search.src_ids,
          encoder_outputs.stochastic_beam_search.src_paddings)

      # STEP 4: Compute updated log_probs. This step is necessary because
      # the output of PreBeamSearchStepCallback must be "per-step"
      # log-probs, whereas so far "cumulative" log-probs have been computed.
      # shape: [tgt_batch, k]
      updated_top_k_log_probs = (
          new_perturbed_cumulative_log_probs -
          last_perturbed_cumulative_log_probs)
      # Convert to the shape [tgt_batch, vocab_size].
      updated_log_probs = tf.fill(
          tf.shape(log_probs),
          tf.constant(LARGE_NEGATIVE_NUMBER, dtype=log_probs.dtype))
      updated_log_probs = _BatchScatter(updated_log_probs, top_k_ids,
                                        updated_top_k_log_probs)

      return (updated_log_probs,
              py_utils.NestedMap(
                  new_perturbed_cumulative_log_probs=(
                      new_perturbed_cumulative_log_probs),
                  top_k_log_probs=top_k_log_probs,
                  top_k_ids=top_k_ids,
              ))

    (bs_results.log_probs, out_states.tmp_states) = tf.cond(
        encoder_outputs.stochastic_beam_search.enable,
        PerturbedLogProbs,
        # No-op.
        lambda: (bs_results.log_probs, states.tmp_states))

    # These states are not updated here but will be updated in
    # PostBeamSearchStepCallback since doing so requires the knowledge of
    # the next step IDs.
    out_states.cumulative_log_probs = states.cumulative_log_probs
    out_states.perturbed_cumulative_log_probs = (
        states.perturbed_cumulative_log_probs)

  return bs_results, out_states
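# The key primitive above is sampling Gumbel noise conditioned on its maximum
# (the trick behind stochastic beam search, Kool et al. 2019, "Stochastic
# Beams and Where to Find Them"). A minimal NumPy sketch of that conditioning
# step; names are illustrative and this makes no claim to match
# _SampleGumbelWithMax exactly (in particular, it is numerically naive).
import numpy as np


def sample_gumbel_with_max(phi, target_max, rng):
  # Unconditional Gumbel(phi) samples.
  gumbel = phi - np.log(-np.log(rng.uniform(size=phi.shape)))
  z = gumbel.max()
  # Shift the samples so that their maximum is exactly target_max.
  return -np.log(np.exp(-target_max) - np.exp(-z) + np.exp(-gumbel))


rng = np.random.default_rng(0)
g = sample_gumbel_with_max(np.log([0.5, 0.3, 0.2]), target_max=0.0, rng=rng)
assert np.isclose(g.max(), 0.0)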
def bucket_fn(num):
  # Drops record if num[0] is odd.
  return tf.cond(
      tf.equal(tf.mod(num[0], 2), 0), lambda: 1,
      lambda: -tf.to_int32(num[0]))